# SocioCausaNet/modeling_joint_causal.py
from __future__ import annotations
from typing import Dict, List, Optional
import torch
import torch.nn as nn
from transformers import AutoModel, PreTrainedModel
from dataclasses import dataclass
try:
from .config import id2label_bio, id2label_rel, id2label_cls
except ImportError:
from config import id2label_bio, id2label_rel, id2label_cls
try:
from .configuration_joint_causal import JointCausalConfig
except ImportError:
from configuration_joint_causal import JointCausalConfig
# ---------------------------------------------------------------------------
# Label maps (inverses of the id2label dicts imported from config)
# ---------------------------------------------------------------------------
label2id_bio = {v: k for k, v in id2label_bio.items()}
label2id_rel = {v: k for k, v in id2label_rel.items()}
label2id_cls = {v: k for k, v in id2label_cls.items()}
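# The BIO map is assumed (from its use below) to contain "O" plus B-/I- tags for
# the roles C (cause), E (effect) and CE (combined cause-effect); index 1 of the
# relation map is treated as the positive "Rel_CE" class, and the classification
# map distinguishes causal from non-causal sentences.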
# ---------------------------------------------------------------------------
# Module overview
# ---------------------------------------------------------------------------
"""Joint Causal Extraction Model (softmax)
============================================================================
A PyTorch module for joint causal extraction using softmax decoding for BIO tagging.
The model supports class weights for handling imbalanced data.

Example::

    >>> model = JointCausalModel(config)  # softmax-based model; ``config`` is a JointCausalConfig
"""
# ---------------------------------------------------------------------------
# Span dataclass
# ---------------------------------------------------------------------------
@dataclass
class Span:
role: str
start_tok: int
end_tok: int
text: str
is_virtual: bool = False
# ---------------------------------------------------------------------------
# Main module
# ---------------------------------------------------------------------------
class JointCausalModel(PreTrainedModel):
"""Encoder + three heads with **optional CRF** BIO decoder.
This model integrates a pre-trained transformer encoder with three distinct
heads for:
1. Classification (cls_head): Predicts a global label for the input.
2. BIO tagging (bio_head): Performs sequence tagging using BIO scheme.
Can operate with a CRF layer or standard softmax.
3. Relation extraction (rel_head): Identifies relations between entities
detected by the BIO tagging head.
"""
    # Link the model to its config class so that save_pretrained / from_pretrained work.
config_class = JointCausalConfig
# ------------------------------------------------------------------
# constructor
    # ------------------------------------------------------------------
def __init__(self, config: JointCausalConfig):
"""Initializes the JointCausalModel.
Args:
encoder_name: Name of the pre-trained transformer model to use
(e.g., "bert-base-uncased").
num_cls_labels: Number of labels for the classification task.
num_bio_labels: Number of labels for the BIO tagging task.
num_rel_labels: Number of labels for the relation extraction task.
dropout: Dropout rate for regularization.
"""
super().__init__(config)
self.config = config
self.enc = AutoModel.from_pretrained(config.encoder_name)
self.hidden_size = self.enc.config.hidden_size
self.dropout = nn.Dropout(config.dropout)
self.layer_norm = nn.LayerNorm(self.hidden_size)
self.cls_head = nn.Sequential(
nn.Linear(self.hidden_size, self.hidden_size // 2),
nn.ReLU(),
nn.Dropout(config.dropout),
nn.Linear(self.hidden_size // 2, config.num_cls_labels),
)
self.bio_head = nn.Sequential(
nn.Linear(self.hidden_size, self.hidden_size),
nn.ReLU(),
nn.Dropout(config.dropout),
nn.Linear(self.hidden_size, self.hidden_size // 2),
nn.ReLU(),
nn.Dropout(config.dropout),
nn.Linear(self.hidden_size // 2, config.num_bio_labels),
)
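        # The relation head consumes the concatenation of a mean-pooled cause
        # vector and a mean-pooled effect vector, hence the 2 * hidden_size input.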
self.rel_head = nn.Sequential(
nn.Linear(self.hidden_size * 2, self.hidden_size),
nn.ReLU(),
nn.Dropout(config.dropout),
nn.Linear(self.hidden_size, self.hidden_size // 2),
nn.ReLU(),
nn.Dropout(config.dropout),
nn.Linear(self.hidden_size // 2, config.num_rel_labels),
)
self._init_new_layer_weights()
def get_config_dict(self) -> Dict:
"""Returns the model's configuration as a dictionary."""
return {
"encoder_name": self.encoder_name,
"num_cls_labels": self.num_cls_labels,
"num_bio_labels": self.num_bio_labels,
"num_rel_labels": self.num_rel_labels,
"dropout": self.dropout_rate,
}
@classmethod
def from_config_dict(cls, config: Dict) -> "JointCausalModel":
"""Creates a JointCausalModel instance from a configuration dictionary."""
        return cls(JointCausalConfig(**config))
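    # Illustrative round trip (assumption: JointCausalConfig accepts these keys
    # as keyword arguments):
    #     cfg_dict = model.get_config_dict()
    #     clone = JointCausalModel.from_config_dict(cfg_dict)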
def _init_new_layer_weights(self):
"""Initializes the weights of the newly added linear layers.
Uses Xavier uniform initialization for weights and zeros for biases.
"""
for mod in [self.cls_head, self.bio_head, self.rel_head]:
for sub_module in mod.modules():
if isinstance(sub_module, nn.Linear):
nn.init.xavier_uniform_(sub_module.weight)
if sub_module.bias is not None:
nn.init.zeros_(sub_module.bias)
def encode(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
"""Encodes the input using the transformer model.
Args:
input_ids: Tensor of input token IDs.
attention_mask: Tensor indicating which tokens to attend to.
Returns:
Tensor of hidden states from the encoder, passed through dropout
and layer normalization.
"""
hidden_states = self.enc(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
return self.layer_norm(self.dropout(hidden_states))
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
*,
bio_labels: torch.Tensor | None = None,
pair_batch: torch.Tensor | None = None,
cause_starts: torch.Tensor | None = None,
cause_ends: torch.Tensor | None = None,
effect_starts: torch.Tensor | None = None,
effect_ends: torch.Tensor | None = None,
) -> Dict[str, torch.Tensor | None]:
"""Performs a forward pass through the model.
Args:
input_ids: Tensor of input token IDs.
attention_mask: Tensor indicating which tokens to attend to.
bio_labels: Optional tensor of BIO labels for training.
pair_batch: Optional tensor indicating which hidden states to use
for relation extraction.
cause_starts: Optional tensor of start indices for cause spans.
cause_ends: Optional tensor of end indices for cause spans.
effect_starts: Optional tensor of start indices for effect spans.
effect_ends: Optional tensor of end indices for effect spans.
Returns:
A dictionary containing:
- "cls_logits": Logits for the classification task.
- "bio_emissions": Emissions from the BIO tagging head.
- "tag_loss": Loss for the BIO tagging task (if bio_labels provided).
- "rel_logits": Logits for the relation extraction task (if
relation extraction inputs provided).
"""
# Encode input
hidden = self.encode(input_ids, attention_mask)
# Classification head
cls_logits = self.cls_head(hidden[:, 0]) # Use [CLS] token representation
# BIO tagging head
emissions = self.bio_head(hidden)
tag_loss: Optional[torch.Tensor] = None
# Calculate BIO tagging loss if labels are provided
if bio_labels is not None:
            # With softmax decoding, the cross-entropy between the emissions and
            # ``bio_labels`` is expected to be computed by the training loop
            # (e.g. with nn.CrossEntropyLoss), so a 0.0 placeholder is returned here.
tag_loss = torch.tensor(0.0, device=emissions.device)
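            # A minimal sketch (assumption, not part of this model) of how a training
            # loop could compute that loss, optionally with class weights for
            # imbalanced BIO tags (``class_weights`` is a hypothetical tensor of
            # shape [num_bio_labels]):
            #
            #     loss_fn = nn.CrossEntropyLoss(weight=class_weights, ignore_index=-100)
            #     tag_loss = loss_fn(
            #         emissions.view(-1, emissions.size(-1)),
            #         bio_labels.view(-1),
            #     )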
# Relation extraction head
rel_logits: torch.Tensor | None = None
if pair_batch is not None and cause_starts is not None and cause_ends is not None \
and effect_starts is not None and effect_ends is not None:
# Select hidden states corresponding to the pairs for relation extraction
bio_states_for_rel = hidden[pair_batch]
seq_len_rel = bio_states_for_rel.size(1)
pos_rel = torch.arange(seq_len_rel, device=bio_states_for_rel.device).unsqueeze(0)
# Create masks for cause and effect spans
c_mask = ((cause_starts.unsqueeze(1) <= pos_rel) & (pos_rel <= cause_ends.unsqueeze(1))).unsqueeze(2)
e_mask = ((effect_starts.unsqueeze(1) <= pos_rel) & (pos_rel <= effect_ends.unsqueeze(1))).unsqueeze(2)
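            # Shapes: bio_states_for_rel is (num_pairs, seq_len, hidden);
            # c_mask / e_mask are (num_pairs, seq_len, 1), so the masked sums below
            # mean-pool each span into a (num_pairs, hidden) vector.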
# Compute mean-pooled representations for cause and effect spans
c_vec = (bio_states_for_rel * c_mask).sum(1) / c_mask.sum(1).clamp(min=1) # Average pooling, clamp to avoid div by zero
e_vec = (bio_states_for_rel * e_mask).sum(1) / e_mask.sum(1).clamp(min=1) # Average pooling, clamp to avoid div by zero
# Concatenate cause and effect vectors and pass through relation head
rel_logits = self.rel_head(torch.cat([c_vec, e_vec], dim=1))
return {
"cls_logits": cls_logits,
"bio_emissions": emissions,
"tag_loss": tag_loss,
"rel_logits": rel_logits,
}
# ---------------------------------------------------------------------------
# Refactored prediction & post‑processing utilities for JointCausalModel
# ---------------------------------------------------------------------------
def predict(
self,
sents: List[str],
tokenizer=None,
*,
rel_mode: str = "neural_only",
rel_threshold: float = 0.8,
cause_decision: str = "cls+span",
) -> List[dict]:
"""End‑to‑end inference for causal sentence extraction (batched).
Args:
sents: List of input sentences for causal extraction.
tokenizer: Tokenizer instance for encoding sentences. If None, a default tokenizer is initialized.
            rel_mode: Strategy for relation extraction. "neural_only" (the default)
                scores every valid cause/effect pair with the relation head, while
                "auto" simplifies relation building when only a single cause or
                effect span is present.
            rel_threshold: Probability threshold applied to the relation head's
                output to filter out spurious pairs.
cause_decision: Strategy for determining causality ('cls_only', 'span_only', or 'cls+span').
Returns:
List of dictionaries containing:
- "text": Original sentence.
- "causal": Boolean indicating if the sentence is causal.
- "relations": List of extracted causal relations.
"""
# ------------------------------------------------------------------
# 0. Tokeniser & device
# ------------------------------------------------------------------
if tokenizer is None:
from transformers import AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained(self.config.encoder_name, use_fast=True)
device = next(self.parameters()).device
to_dev = lambda d: {k: v.to(device) for k, v in d.items()} # Move tensors to the model's device
outputs: List[dict] = []
# ------------------------------------------------------------------
# 1. Batch tokenize all sentences
# ------------------------------------------------------------------
enc = tokenizer(sents, return_tensors="pt", truncation=True, max_length=512, padding=True)
enc = to_dev(enc) # Ensure tensors are on the correct device
with torch.no_grad():
base = self(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])
cls_logits_batch = base["cls_logits"] # Sentence-level classification logits
bio_emissions_batch = base["bio_emissions"] # BIO tagging emissions
input_ids_batch = enc["input_ids"] # Token IDs for each sentence
attention_mask_batch = enc["attention_mask"] # Attention mask for each sentence
batch_size = input_ids_batch.size(0)
for i in range(batch_size):
seq_len = attention_mask_batch[i].sum().item() # Determine the actual sequence length
input_ids = input_ids_batch[i][:seq_len] # Trim padding tokens
bio_emissions = bio_emissions_batch[i][:seq_len] # Trim emissions to sequence length
tokens = tokenizer.convert_ids_to_tokens(input_ids) # Convert token IDs to actual tokens
bio_ids = bio_emissions.argmax(-1).tolist() # Get predicted BIO label indices
bio_labels = [id2label_bio[j] for j in bio_ids] # Map indices to label names
# Apply BIO rules to clean up predictions
fixed_labels = self._apply_bio_rules(tokens, bio_labels)
spans = self._merge_spans(tokens, fixed_labels, tokenizer) # Merge spans based on cleaned labels
# Determine if the sentence is causal based on classification logits and spans
is_causal = self._decide_causal(cls_logits_batch[i], spans, cause_decision)
# ------------------------------------------------------------------
# 2. Relation extraction (per sentence, as before)
# ------------------------------------------------------------------
rels: List[dict] = []
pure_cause_spans = [s for s in spans if s.role == "C"] # Extract pure cause spans
pure_effect_spans = [s for s in spans if s.role == "E"] # Extract pure effect spans
ce_spans = [s for s in spans if s.role == "CE"] # Extract combined cause-effect spans
cause_spans = pure_cause_spans + ce_spans
effect_spans = pure_effect_spans + ce_spans
if cause_spans and effect_spans:
# Check for presence of pure causes/effects and combined spans
has_pure_causes = len(pure_cause_spans) > 0
has_pure_effects = len(pure_effect_spans) > 0
has_ce_spans = len(ce_spans) > 0
if has_ce_spans and not (has_pure_causes or has_pure_effects):
pass # Skip relation extraction if only combined spans exist
elif rel_mode == "auto" and (len(cause_spans) == 1 or len(effect_spans) == 1):
# Simplified relation extraction for single spans
if len(cause_spans) == 1:
for e in effect_spans:
if (cause_spans[0].text.lower() != e.text.lower() or
(cause_spans[0].role == "CE" and e.role != "CE")):
rels.append({"cause": cause_spans[0].text, "effect": e.text, "type": "Rel_CE"})
else:
for c in cause_spans:
if (c.text.lower() != effect_spans[0].text.lower() or
(c.role == "CE" and effect_spans[0].role != "CE")):
rels.append({"cause": c.text, "effect": effect_spans[0].text, "type": "Rel_CE"})
elif rel_mode == "neural_only":
# Always use the relation head for all valid pairs
pair_meta = []
for c in cause_spans:
for e in effect_spans:
if (not (c.start_tok == e.start_tok and c.end_tok == e.end_tok) or
(c.role == "CE" and e.role in {"C", "E"}) or
(c.role in {"C", "E"} and e.role == "CE")):
pair_meta.append((c, e))
if pair_meta:
# Prepare tensors for this sentence only
pair_batch = torch.zeros(len(pair_meta), dtype=torch.long, device=device)
cause_starts = torch.tensor([c.start_tok for c, _ in pair_meta], device=device)
cause_ends = torch.tensor([c.end_tok for c, _ in pair_meta], device=device)
effect_starts = torch.tensor([e.start_tok for _, e in pair_meta], device=device)
effect_ends = torch.tensor([e.end_tok for _, e in pair_meta], device=device)
rel_logits = self(
input_ids=input_ids.unsqueeze(0),
attention_mask=attention_mask_batch[i][:seq_len].unsqueeze(0),
pair_batch=pair_batch,
cause_starts=cause_starts,
cause_ends=cause_ends,
effect_starts=effect_starts,
effect_ends=effect_ends,
)["rel_logits"]
                        probs = torch.softmax(rel_logits, dim=-1)[:, 1].tolist()  # probability of the positive relation class (index 1)
for (c, e), p in zip(pair_meta, probs):
if p >= rel_threshold and c.text.lower() != e.text.lower():
rels.append({"cause": c.text, "effect": e.text, "type": "Rel_CE"})
# Remove duplicate relations
seen = set()
uniq = []
for r in rels:
key = (r["cause"].lower(), r["effect"].lower())
if key not in seen:
seen.add(key)
uniq.append(r)
rels = uniq
# If the sentence is predicted as non-causal, ensure no spans or relations are returned
if not is_causal:
outputs.append({
"text": sents[i],
"causal": is_causal,
"relations": [], # Empty relations
"spans": [], # Empty spans
})
else:
outputs.append({
"text": sents[i],
"causal": is_causal,
"relations": rels,
})
return outputs
# ------------------------------------------------------------------
# BIO utilities
# ------------------------------------------------------------------
@staticmethod
def _apply_bio_rules(tok: List[str], lab: List[str]) -> List[str]:
"""Light‑touch BIO sanitiser that fixes **intra‑span role clashes** and
common WordPiece artefacts while deferring to model probabilities.
Added rule (R‑6)
----------------
When a contiguous non‑O block mixes **C** and **E** roles (e.g.
``B‑C I‑C I‑E I‑C``) we collapse the entire block to the *majority*
role (ties prefer **C**). Only the first token keeps the ``B‑`` prefix.
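
        Example (illustrative, with the label set used below):
            tokens  ["smok", "##ing", "causes", "lung", "cancer"]
            input   ["B-C",  "O",     "O",      "B-E",  "I-E"]
            output  ["B-C",  "I-C",   "O",      "B-E",  "I-E"]   # R-1 fills the "##ing" piece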
"""
n = len(tok)
out = lab.copy()
# R‑1 propagate to ## -------------------------------------------------
for i in range(1, n):
if tok[i].startswith("##") and out[i] == "O" and out[i-1] != "O":
role = out[i-1].split("-")[-1]
out[i] = f"I-{role}"
# R‑2 stray I‑tags → B ----------------------------------------------
for i in range(n):
if out[i].startswith("I-") and (i == 0 or out[i-1] == "O"):
out[i] = out[i].replace("I-", "B-", 1)
# R‑3 merge adjacent B blocks of same role ---------------------------
for i in range(1, n):
if out[i].startswith("B-") and out[i-1] != "O":
role_prev = out[i-1].split("-")[-1]
role_curr = out[i].split("-")[-1]
if role_prev == role_curr:
out[i] = out[i].replace("B-", "I-", 1)
# R‑4 (removed): We no longer force punctuation tokens to O
        # This keeps apostrophes/hyphens inside spans when the model labels them.
        # R‑5 CE disambiguation: only convert CE if no other roles are present
roles_present = {tag.split("-")[-1] for tag in out if tag != "O"}
if "CE" in roles_present and "C" not in roles_present and "E" not in roles_present:
# Only CE tags present - convert all to C (arbitrary choice)
for i, tag in enumerate(out):
if tag.endswith("CE"):
out[i] = tag[:-2] + "C"
# R‑6 intra‑span role clash fix - preserve CE spans when meaningful
i = 0
while i < n:
if out[i] == "O":
i += 1
continue
start = i
role_counts = {"C": 0, "E": 0, "CE": 0}
has_mixed_roles = False
# Count roles in this span and check for mixing
while i < n and out[i] != "O" and not (i > start and out[i].startswith("B-")):
role = out[i].split("-")[-1]
role_counts[role] += 1
i += 1
# Check if span has mixed C/E roles (not including CE)
non_ce_roles = set()
j = start
while j < i:
role = out[j].split("-")[-1]
if role in {"C", "E"}:
non_ce_roles.add(role)
j += 1
if len(non_ce_roles) > 1:
# Mixed C and E tags - resolve to majority
maj = "C" if role_counts["C"] >= role_counts["E"] else "E"
j = start
first = True
while j < i:
out[j] = ("B-" if first else "I-") + maj
first = False
j += 1
elif role_counts["CE"] > 0 and len(non_ce_roles) == 0:
# Pure CE span - keep as CE
j = start
first = True
while j < i:
out[j] = ("B-" if first else "I-") + "CE"
first = False
j += 1
elif role_counts["CE"] > 0 and len(non_ce_roles) == 1:
# CE mixed with single pure role - check if CE is meaningful
# If we have other pure spans of different types, keep CE
other_roles = {tag.split("-")[-1] for tag in out if tag != "O"}
pure_role = list(non_ce_roles)[0]
if (pure_role == "C" and "E" in other_roles) or (pure_role == "E" and "C" in other_roles):
# CE is meaningful - keep it
j = start
first = True
while j < i:
out[j] = ("B-" if first else "I-") + "CE"
first = False
j += 1
else:
# CE not meaningful - convert to pure role
j = start
first = True
while j < i:
out[j] = ("B-" if first else "I-") + pure_role
first = False
j += 1
# R‑7 connector & punctuation bridge ----------------------------------
CONNECT = {"of", "to", "with", "for", "and", "or", "but", "in"}
for k in range(1, n - 1):
left_role = out[k - 1].split("-")[-1] if out[k - 1] != "O" else None
right_role = out[k + 1].split("-")[-1] if out[k + 1] != "O" else None
if not left_role or left_role != right_role:
continue
# 7a: connector word originally tagged O
if out[k] == "O" and tok[k].lower() in CONNECT:
out[k] = "I-" + left_role
# 7b: single‑char punctuation / hyphen / apostrophe bridge
elif out[k] == "O" and len(tok[k]) == 1 and not tok[k].isalnum():
out[k] = "I-" + left_role # 7c: mis‑role single token sandwiched by same role
elif out[k].startswith("I-") and out[k].split("-")[-1] != left_role:
out[k] = "I-" + left_role
# R‑8 gap‑tolerant B‑tag merging ------------------------------------
# Merge B- tags of the same type separated by small gaps (≤1 O tokens)
# This reduces span fragmentation like "B-E O B-E" -> "B-E I-E I-E"
b_positions = {}
for i, label in enumerate(out):
if label.startswith("B-"):
role = label.split("-")[1]
if role not in b_positions:
b_positions[role] = []
b_positions[role].append(i)
for role, positions in b_positions.items():
if len(positions) < 2:
continue
# Group positions that are close together (gap ≤ 1)
groups = []
current_group = [positions[0]]
for i in range(1, len(positions)):
prev_pos = positions[i-1]
curr_pos = positions[i]
gap_size = curr_pos - prev_pos - 1
if gap_size <= 1: # Allow gaps of 0 or 1 O tokens
gap_labels = out[prev_pos + 1:curr_pos]
if all(label == "O" for label in gap_labels):
current_group.append(curr_pos)
else:
groups.append(current_group)
current_group = [curr_pos]
else:
groups.append(current_group)
current_group = [curr_pos]
groups.append(current_group)
# Merge groups with multiple B- tags
for group in groups:
if len(group) > 1:
first_pos = group[0]
last_pos = group[-1]
for pos in range(first_pos + 1, last_pos + 1):
if pos in group[1:]: # B- tag to convert
out[pos] = f"I-{role}"
elif out[pos] == "O": # Fill gap
out[pos] = f"I-{role}"
return out
# ------------------------------------------------------------------
@staticmethod
def _merge_spans(tok: List[str], lab: List[str], tokenizer) -> List["Span"]:
"""Turn cleaned BIO labels into Span objects.
Policy:
1. **First pass** – assemble raw spans, letting them bridge a single
connector (of, to, with, for, and, or, but, in).
2. **Trim** leading/trailing connectors & punctuation.
3. **Normalise** hyphen spacing & strip quotes.
4. **Role‑wise pruning** – if a role has ≥1 span with *≥2 words*, drop
*all* its 1‑word spans. This removes stray nouns like "choices"
while preserving them when they are the *only* cause/effect.
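
        Example (illustrative): tokens ["lack", "of", "sleep", "causes", "fatigue"]
        with labels ["B-C", "O", "I-C", "O", "B-E"] produce
        Span("C", 0, 2, "lack of sleep") and Span("E", 4, 4, "fatigue").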
"""
CONNECT = {"of", "to", "with", "for", "and", "or", "but", "in"}
spans: List[Span] = []
i, n = 0, len(tok)
while i < n:
if lab[i] == "O":
i += 1; continue
role = lab[i].split("-")[-1]
s = i
i += 1
while i < n:
if lab[i].startswith("I-"):
i += 1; continue
if tok[i].lower() in CONNECT and lab[i] == "O" and i+1 < n and lab[i+1].startswith("I-"):
i += 1; continue
break
e = i - 1
text = tokenizer.convert_tokens_to_string(tok[s:e+1])
# basic cleanup
text = text.replace(" - ", "-").replace(" -", "-").replace("- ", "-")
text = text.strip("\"'”’““”")
words = text.split()
while words and words[0].lower() in CONNECT:
words.pop(0)
while words and words[-1].lower() in CONNECT:
words.pop()
if not words:
continue
clean_text = " ".join(words)
            spans.append(Span(role, s, e, clean_text))
        # role‑wise pruning --------------------------------------------------
        from collections import defaultdict
by_role = defaultdict(list)
for sp in spans:
by_role[sp.role].append(sp)
final: List[Span] = []
for role, group in by_role.items():
has_multi = any((g.end_tok - g.start_tok) >= 1 for g in group)
for sp in group:
single_tok = (sp.end_tok - sp.start_tok) == 0
# Only remove single-token spans if they look like artifacts
# Keep all meaningful single-token spans like "depression", "cancer", etc.
if single_tok:
# Check if the span text looks like a meaningful entity
is_meaningful = (
len(sp.text) > 2 and # Longer than 2 characters
sp.text.isalpha() and # Only alphabetic characters
not sp.text.lower() in {"this", "that", "it", "they", "them", "he", "she", "we", "i", "you"} # Not pronouns
)
if not is_meaningful and has_multi:
# Only skip single-token spans that seem like artifacts when multi-token spans exist
if role == "C" or role == "E":
continue
final.append(sp)
final.sort(key=lambda s: s.start_tok)
# second pass: merge over *pure punctuation* gaps only -----------------
merged: List[Span] = []
        def is_punct(t: str) -> bool:
            return len(t) == 1 and not t.isalnum()
for sp in final:
if merged and sp.role == merged[-1].role:
gap_tokens = tok[merged[-1].end_tok + 1 : sp.start_tok]
if gap_tokens and all(is_punct(t) for t in gap_tokens):
# safe to merge across punctuation (e.g., apostrophe or hyphen)
combined_text = tokenizer.convert_tokens_to_string(tok[merged[-1].start_tok: sp.end_tok + 1]).strip("\"'”’““”")
merged[-1] = Span(sp.role, merged[-1].start_tok, sp.end_tok, combined_text)
continue
merged.append(sp)
return merged
def _decide_causal(self, cls_logits, spans, cause_decision):
"""Determine if a sentence is causal based on classification logits and spans.
Args:
cls_logits: Tensor of classification logits
spans: List of extracted spans
cause_decision: Strategy for determining causality ('cls_only', 'span_only', or 'cls+span')
Returns:
bool: True if the sentence is determined to be causal
"""
prob_causal = torch.softmax(cls_logits, dim=-1)[1].item()
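        # Index 1 of the classification logits is treated as the "causal" class.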
# Check for presence of both cause and effect spans (CE spans count as both)
has_cause_spans = any(x.role in ("C", "CE") for x in spans)
has_effect_spans = any(x.role in ("E", "CE") for x in spans)
has_both_spans = has_cause_spans and has_effect_spans
if cause_decision == "cls_only":
return prob_causal >= 0.5
elif cause_decision == "span_only":
return has_both_spans
else: # "cls+span" - default behavior
return prob_causal >= 0.5 and has_both_spans
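

# ---------------------------------------------------------------------------
# Usage sketch
# ---------------------------------------------------------------------------
# A minimal, illustrative example of running inference with this model. The
# checkpoint id "rasoultilburg/SocioCausaNet" is an assumption based on the
# repository name; substitute the actual hub id or a local path.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    ckpt = "rasoultilburg/SocioCausaNet"  # hypothetical checkpoint id
    tokenizer = AutoTokenizer.from_pretrained(ckpt, use_fast=True)
    model = JointCausalModel.from_pretrained(ckpt).eval()

    results = model.predict(
        ["Prolonged unemployment increases the risk of depression."],
        tokenizer=tokenizer,
        rel_mode="neural_only",
        rel_threshold=0.8,
        cause_decision="cls+span",
    )
    for r in results:
        print(r["text"], "->", r["causal"], r["relations"])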