"""Joint Causal Extraction Model (softmax)
=========================================

A PyTorch module for joint causal extraction using softmax decoding for
BIO tagging. The model supports class weights for handling imbalanced data.

Example:
    >>> model = JointCausalModel(JointCausalConfig())  # softmax-based model
"""

from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional

import torch
import torch.nn as nn
from transformers import AutoModel, PreTrainedModel

try:
    from .config import id2label_bio, id2label_rel, id2label_cls
except ImportError:
    from config import id2label_bio, id2label_rel, id2label_cls

try:
    from .configuration_joint_causal import JointCausalConfig
except ImportError:
    from configuration_joint_causal import JointCausalConfig

label2id_bio = {v: k for k, v in id2label_bio.items()}
label2id_rel = {v: k for k, v in id2label_rel.items()}
label2id_cls = {v: k for k, v in id2label_cls.items()}
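
# Hedged sketch of the label inventories assumed from config.py (illustrative
# values only; the authoritative maps are defined there):
#   id2label_bio ~ {0: "O", 1: "B-C", 2: "I-C", 3: "B-E", 4: "I-E", 5: "B-CE", 6: "I-CE"}
#   id2label_rel ~ {0: "No_Rel", 1: "Rel_CE"}
#   id2label_cls ~ {0: "non-causal", 1: "causal"}
# The label2id_* dicts above are simply their inverted views.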


@dataclass
class Span:
    """A contiguous cause/effect span recovered from BIO tags."""

    role: str  # "C" (cause), "E" (effect), or "CE" (both roles)
    start_tok: int  # index of the first WordPiece token (inclusive)
    end_tok: int  # index of the last WordPiece token (inclusive)
    text: str  # detokenized surface text
    is_virtual: bool = False  # reserved flag; never set within this module


class JointCausalModel(PreTrainedModel):
    """Encoder plus three task heads with a softmax BIO decoder.

    This model integrates a pre-trained transformer encoder with three
    distinct heads:

    1. Classification (cls_head): predicts a global causal/non-causal label
       for the input.
    2. BIO tagging (bio_head): performs sequence tagging using the BIO
       scheme, decoded with softmax (this variant has no CRF layer).
    3. Relation extraction (rel_head): classifies relations between entities
       detected by the BIO tagging head.
    """

    config_class = JointCausalConfig

    def __init__(self, config: JointCausalConfig):
        """Initializes the JointCausalModel.

        Args:
            config: A JointCausalConfig carrying `encoder_name` (the
                pre-trained transformer to load, e.g. "bert-base-uncased"),
                `num_cls_labels`, `num_bio_labels`, `num_rel_labels`, and
                `dropout`.
        """
        super().__init__(config)
        self.config = config

        self.enc = AutoModel.from_pretrained(config.encoder_name)
        self.hidden_size = self.enc.config.hidden_size
        self.dropout = nn.Dropout(config.dropout)
        self.layer_norm = nn.LayerNorm(self.hidden_size)

        # Sentence-level classification head: [CLS] vector -> num_cls_labels.
        self.cls_head = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(self.hidden_size // 2, config.num_cls_labels),
        )
        # Token-level BIO tagging head: per-token emissions.
        self.bio_head = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(self.hidden_size, self.hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(self.hidden_size // 2, config.num_bio_labels),
        )
        # Relation head: concatenated [cause; effect] span vectors.
        self.rel_head = nn.Sequential(
            nn.Linear(self.hidden_size * 2, self.hidden_size),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(self.hidden_size, self.hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(self.hidden_size // 2, config.num_rel_labels),
        )
        self._init_new_layer_weights()

    def get_config_dict(self) -> Dict:
        """Returns the model's configuration as a dictionary."""
        return {
            "encoder_name": self.config.encoder_name,
            "num_cls_labels": self.config.num_cls_labels,
            "num_bio_labels": self.config.num_bio_labels,
            "num_rel_labels": self.config.num_rel_labels,
            "dropout": self.config.dropout,
        }

    @classmethod
    def from_config_dict(cls, config: Dict) -> "JointCausalModel":
        """Creates a JointCausalModel instance from a configuration dictionary."""
        return cls(JointCausalConfig(**config))
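
    # Hedged round-trip sketch (assumes JointCausalConfig accepts exactly the
    # keys returned by get_config_dict; encoder weights are fetched from the
    # HF hub on first construction):
    #   cfg = JointCausalConfig(encoder_name="bert-base-uncased",
    #                           num_cls_labels=2, num_bio_labels=7,
    #                           num_rel_labels=2, dropout=0.1)
    #   model = JointCausalModel(cfg)
    #   clone = JointCausalModel.from_config_dict(model.get_config_dict())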

    def _init_new_layer_weights(self):
        """Initializes the weights of the newly added linear layers.

        Uses Xavier uniform initialization for weights and zeros for biases.
        """
        for mod in [self.cls_head, self.bio_head, self.rel_head]:
            for sub_module in mod.modules():
                if isinstance(sub_module, nn.Linear):
                    nn.init.xavier_uniform_(sub_module.weight)
                    if sub_module.bias is not None:
                        nn.init.zeros_(sub_module.bias)

    def encode(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Encodes the input using the transformer model.

        Args:
            input_ids: Tensor of input token IDs.
            attention_mask: Tensor indicating which tokens to attend to.

        Returns:
            Hidden states from the encoder, passed through dropout and
            layer normalization.
        """
        hidden_states = self.enc(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        return self.layer_norm(self.dropout(hidden_states))
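
    # Shape note: for input_ids of shape (batch, seq_len), encode() returns
    # (batch, seq_len, hidden_size), where hidden_size comes from the
    # underlying encoder config (e.g. 768 for bert-base-uncased).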

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        *,
        bio_labels: torch.Tensor | None = None,
        pair_batch: torch.Tensor | None = None,
        cause_starts: torch.Tensor | None = None,
        cause_ends: torch.Tensor | None = None,
        effect_starts: torch.Tensor | None = None,
        effect_ends: torch.Tensor | None = None,
    ) -> Dict[str, torch.Tensor | None]:
        """Performs a forward pass through the model.

        Args:
            input_ids: Tensor of input token IDs.
            attention_mask: Tensor indicating which tokens to attend to.
            bio_labels: Optional tensor of BIO labels for training.
            pair_batch: Optional tensor mapping each candidate pair to the
                batch row whose hidden states it should use for relation
                extraction.
            cause_starts: Optional tensor of start indices for cause spans.
            cause_ends: Optional tensor of end indices for cause spans.
            effect_starts: Optional tensor of start indices for effect spans.
            effect_ends: Optional tensor of end indices for effect spans.

        Returns:
            A dictionary containing:
                - "cls_logits": Logits for the classification task.
                - "bio_emissions": Emissions from the BIO tagging head.
                - "tag_loss": Loss for the BIO tagging task (if bio_labels
                  provided).
                - "rel_logits": Logits for the relation extraction task (if
                  relation extraction inputs provided).
        """
        hidden = self.encode(input_ids, attention_mask)

        # Sentence-level logits from the [CLS] position.
        cls_logits = self.cls_head(hidden[:, 0])

        # Per-token BIO emissions.
        emissions = self.bio_head(hidden)
        tag_loss: Optional[torch.Tensor] = None
        if bio_labels is not None:
            # Token-level cross-entropy over the emissions; padded positions
            # are assumed to carry the conventional ignore index -100. Class
            # weights for imbalanced tags could be supplied via the `weight`
            # argument here.
            tag_loss = nn.functional.cross_entropy(
                emissions.reshape(-1, emissions.size(-1)),
                bio_labels.reshape(-1),
                ignore_index=-100,
            )

        rel_logits: torch.Tensor | None = None
        if pair_batch is not None and cause_starts is not None and cause_ends is not None \
                and effect_starts is not None and effect_ends is not None:
            # Gather the hidden states of the sentence each pair belongs to.
            bio_states_for_rel = hidden[pair_batch]
            seq_len_rel = bio_states_for_rel.size(1)
            pos_rel = torch.arange(seq_len_rel, device=bio_states_for_rel.device).unsqueeze(0)

            # Boolean masks selecting the cause/effect token ranges (inclusive).
            c_mask = ((cause_starts.unsqueeze(1) <= pos_rel) & (pos_rel <= cause_ends.unsqueeze(1))).unsqueeze(2)
            e_mask = ((effect_starts.unsqueeze(1) <= pos_rel) & (pos_rel <= effect_ends.unsqueeze(1))).unsqueeze(2)

            # Mean-pool each span; clamp avoids division by zero on empty masks.
            c_vec = (bio_states_for_rel * c_mask).sum(1) / c_mask.sum(1).clamp(min=1)
            e_vec = (bio_states_for_rel * e_mask).sum(1) / e_mask.sum(1).clamp(min=1)

            rel_logits = self.rel_head(torch.cat([c_vec, e_vec], dim=1))

        return {
            "cls_logits": cls_logits,
            "bio_emissions": emissions,
            "tag_loss": tag_loss,
            "rel_logits": rel_logits,
        }
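
    # Hedged usage sketch for the bare forward pass (tokenizer name assumed
    # to match config.encoder_name):
    #   from transformers import AutoTokenizer
    #   tok = AutoTokenizer.from_pretrained("bert-base-uncased")
    #   enc = tok(["Smoking causes cancer."], return_tensors="pt")
    #   out = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])
    #   out["cls_logits"].shape     # (1, num_cls_labels)
    #   out["bio_emissions"].shape  # (1, seq_len, num_bio_labels)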

    def predict(
        self,
        sents: List[str],
        tokenizer=None,
        *,
        rel_mode: str = "neural_only",
        rel_threshold: float = 0.8,
        cause_decision: str = "cls+span",
    ) -> List[dict]:
        """End-to-end inference for causal sentence extraction (batched).

        Args:
            sents: List of input sentences for causal extraction.
            tokenizer: Tokenizer instance for encoding sentences. If None, a
                default tokenizer is loaded from the configured encoder.
            rel_mode: Strategy for relation extraction. "auto" pairs spans
                heuristically when only one cause or one effect is present;
                any other value (including the default "neural_only") scores
                every candidate pair with the relation head.
            rel_threshold: Probability threshold for the relation head, used
                to suppress spurious pairs.
            cause_decision: Strategy for determining causality ("cls_only",
                "span_only", or "cls+span").

        Returns:
            List of dictionaries containing:
                - "text": Original sentence.
                - "causal": Boolean indicating if the sentence is causal.
                - "relations": List of extracted causal relations.
        """
        if tokenizer is None:
            from transformers import AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained(self.config.encoder_name, use_fast=True)

        device = next(self.parameters()).device
        to_dev = lambda d: {k: v.to(device) for k, v in d.items()}

        outputs: List[dict] = []

        enc = tokenizer(sents, return_tensors="pt", truncation=True, max_length=512, padding=True)
        enc = to_dev(enc)

        with torch.no_grad():
            base = self(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])

        cls_logits_batch = base["cls_logits"]
        bio_emissions_batch = base["bio_emissions"]
        input_ids_batch = enc["input_ids"]
        attention_mask_batch = enc["attention_mask"]

        batch_size = input_ids_batch.size(0)

        for i in range(batch_size):
            seq_len = attention_mask_batch[i].sum().item()
            input_ids = input_ids_batch[i][:seq_len]
            bio_emissions = bio_emissions_batch[i][:seq_len]
            tokens = tokenizer.convert_ids_to_tokens(input_ids)
            bio_ids = bio_emissions.argmax(-1).tolist()
            bio_labels = [id2label_bio[j] for j in bio_ids]

            # Repair BIO artefacts, then merge tags into Span objects.
            fixed_labels = self._apply_bio_rules(tokens, bio_labels)
            spans = self._merge_spans(tokens, fixed_labels, tokenizer)

            is_causal = self._decide_causal(cls_logits_batch[i], spans, cause_decision)

            # Candidate causes/effects; "CE" spans can act as either role.
            rels: List[dict] = []
            pure_cause_spans = [s for s in spans if s.role == "C"]
            pure_effect_spans = [s for s in spans if s.role == "E"]
            ce_spans = [s for s in spans if s.role == "CE"]
            cause_spans = pure_cause_spans + ce_spans
            effect_spans = pure_effect_spans + ce_spans

            if cause_spans and effect_spans:
                has_pure_causes = len(pure_cause_spans) > 0
                has_pure_effects = len(pure_effect_spans) > 0
                has_ce_spans = len(ce_spans) > 0

                if has_ce_spans and not (has_pure_causes or has_pure_effects):
                    # Only CE spans: no usable cause/effect pairing.
                    pass
                elif rel_mode == "auto" and (len(cause_spans) == 1 or len(effect_spans) == 1):
                    # Heuristic pairing when one side has a single span.
                    if len(cause_spans) == 1:
                        for e in effect_spans:
                            if (cause_spans[0].text.lower() != e.text.lower() or
                                    (cause_spans[0].role == "CE" and e.role != "CE")):
                                rels.append({"cause": cause_spans[0].text, "effect": e.text, "type": "Rel_CE"})
                    else:
                        for c in cause_spans:
                            if (c.text.lower() != effect_spans[0].text.lower() or
                                    (c.role == "CE" and effect_spans[0].role != "CE")):
                                rels.append({"cause": c.text, "effect": effect_spans[0].text, "type": "Rel_CE"})
                else:
                    # "neural_only" (default) and all remaining modes: score
                    # every candidate pair with the relation head.
                    pair_meta = []
                    for c in cause_spans:
                        for e in effect_spans:
                            if (not (c.start_tok == e.start_tok and c.end_tok == e.end_tok) or
                                    (c.role == "CE" and e.role in {"C", "E"}) or
                                    (c.role in {"C", "E"} and e.role == "CE")):
                                pair_meta.append((c, e))
                    if pair_meta:
                        # All pairs reference the same single-sentence batch, row 0.
                        pair_batch = torch.zeros(len(pair_meta), dtype=torch.long, device=device)
                        cause_starts = torch.tensor([c.start_tok for c, _ in pair_meta], device=device)
                        cause_ends = torch.tensor([c.end_tok for c, _ in pair_meta], device=device)
                        effect_starts = torch.tensor([e.start_tok for _, e in pair_meta], device=device)
                        effect_ends = torch.tensor([e.end_tok for _, e in pair_meta], device=device)
                        with torch.no_grad():
                            rel_logits = self(
                                input_ids=input_ids.unsqueeze(0),
                                attention_mask=attention_mask_batch[i][:seq_len].unsqueeze(0),
                                pair_batch=pair_batch,
                                cause_starts=cause_starts,
                                cause_ends=cause_ends,
                                effect_starts=effect_starts,
                                effect_ends=effect_ends,
                            )["rel_logits"]
                        probs = torch.softmax(rel_logits, dim=-1)[:, 1].tolist()
                        for (c, e), p in zip(pair_meta, probs):
                            if p >= rel_threshold and c.text.lower() != e.text.lower():
                                rels.append({"cause": c.text, "effect": e.text, "type": "Rel_CE"})

            # De-duplicate relations case-insensitively, preserving order.
            seen = set()
            uniq = []
            for r in rels:
                key = (r["cause"].lower(), r["effect"].lower())
                if key not in seen:
                    seen.add(key)
                    uniq.append(r)
            rels = uniq

            outputs.append({
                "text": sents[i],
                "causal": is_causal,
                "relations": rels if is_causal else [],
            })

        return outputs
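
    # Hedged end-to-end sketch (output values are illustrative, not real
    # model output):
    #   model.eval()
    #   model.predict(["Heavy rain caused severe flooding."])
    #   # -> [{"text": "Heavy rain caused severe flooding.",
    #   #      "causal": True,
    #   #      "relations": [{"cause": "Heavy rain",
    #   #                     "effect": "severe flooding",
    #   #                     "type": "Rel_CE"}]}]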

    @staticmethod
    def _apply_bio_rules(tok: List[str], lab: List[str]) -> List[str]:
        """Light-touch BIO sanitiser that fixes **intra-span role clashes** and
        common WordPiece artefacts while deferring to model probabilities.

        Added rule (R-6)
        ----------------
        When a contiguous non-O block mixes **C** and **E** roles (e.g.
        ``B-C I-C I-E I-C``) we collapse the entire block to the *majority*
        role (ties prefer **C**). Only the first token keeps the ``B-`` prefix.
        """
        n = len(tok)
        out = lab.copy()

        # Attach O-tagged WordPiece continuations ("##...") to the preceding span.
        for i in range(1, n):
            if tok[i].startswith("##") and out[i] == "O" and out[i - 1] != "O":
                role = out[i - 1].split("-")[-1]
                out[i] = f"I-{role}"

        # Promote orphan I- tags (no preceding span token) to B-.
        for i in range(n):
            if out[i].startswith("I-") and (i == 0 or out[i - 1] == "O"):
                out[i] = out[i].replace("I-", "B-", 1)

        # Demote B- to I- when it continues a span of the same role.
        for i in range(1, n):
            if out[i].startswith("B-") and out[i - 1] != "O":
                role_prev = out[i - 1].split("-")[-1]
                role_curr = out[i].split("-")[-1]
                if role_prev == role_curr:
                    out[i] = out[i].replace("B-", "I-", 1)

        # If only CE spans are present, re-interpret them as causes.
        roles_present = {tag.split("-")[-1] for tag in out if tag != "O"}
        if "CE" in roles_present and "C" not in roles_present and "E" not in roles_present:
            for i, tag in enumerate(out):
                if tag.endswith("CE"):
                    out[i] = tag[:-2] + "C"

        # R-6: resolve role clashes inside each contiguous non-O block.
        i = 0
        while i < n:
            if out[i] == "O":
                i += 1
                continue
            start = i
            role_counts = {"C": 0, "E": 0, "CE": 0}

            # Walk the block: it ends at O or at the next B- tag.
            while i < n and out[i] != "O" and not (i > start and out[i].startswith("B-")):
                role = out[i].split("-")[-1]
                role_counts[role] += 1
                i += 1

            # Collect the pure (non-CE) roles occurring in the block.
            non_ce_roles = set()
            j = start
            while j < i:
                role = out[j].split("-")[-1]
                if role in {"C", "E"}:
                    non_ce_roles.add(role)
                j += 1

            if len(non_ce_roles) > 1:
                # Mixed C/E block: collapse to the majority role (ties prefer C).
                maj = "C" if role_counts["C"] >= role_counts["E"] else "E"
                j = start
                first = True
                while j < i:
                    out[j] = ("B-" if first else "I-") + maj
                    first = False
                    j += 1
            elif role_counts["CE"] > 0 and len(non_ce_roles) == 0:
                # Pure CE block: normalise B-/I- prefixes.
                j = start
                first = True
                while j < i:
                    out[j] = ("B-" if first else "I-") + "CE"
                    first = False
                    j += 1
            elif role_counts["CE"] > 0 and len(non_ce_roles) == 1:
                # CE mixed with one pure role: keep CE only if the opposite
                # pure role exists elsewhere in the sentence.
                other_roles = {tag.split("-")[-1] for tag in out if tag != "O"}
                pure_role = list(non_ce_roles)[0]

                if (pure_role == "C" and "E" in other_roles) or (pure_role == "E" and "C" in other_roles):
                    j = start
                    first = True
                    while j < i:
                        out[j] = ("B-" if first else "I-") + "CE"
                        first = False
                        j += 1
                else:
                    j = start
                    first = True
                    while j < i:
                        out[j] = ("B-" if first else "I-") + pure_role
                        first = False
                        j += 1

        # Bridge single O-tagged connectors/punctuation between same-role
        # spans, and align lone I- tags that clash with both neighbours.
        CONNECT = {"of", "to", "with", "for", "and", "or", "but", "in"}
        for k in range(1, n - 1):
            left_role = out[k - 1].split("-")[-1] if out[k - 1] != "O" else None
            right_role = out[k + 1].split("-")[-1] if out[k + 1] != "O" else None
            if not left_role or left_role != right_role:
                continue
            if out[k] == "O" and tok[k].lower() in CONNECT:
                out[k] = "I-" + left_role
            elif out[k] == "O" and len(tok[k]) == 1 and not tok[k].isalnum():
                out[k] = "I-" + left_role
            elif out[k].startswith("I-") and out[k].split("-")[-1] != left_role:
                out[k] = "I-" + left_role

        # Merge runs of same-role B- spans separated by at most one O token.
        b_positions = {}
        for i, label in enumerate(out):
            if label.startswith("B-"):
                role = label.split("-")[1]
                if role not in b_positions:
                    b_positions[role] = []
                b_positions[role].append(i)

        for role, positions in b_positions.items():
            if len(positions) < 2:
                continue

            # Group B- positions whose gaps contain only O labels.
            groups = []
            current_group = [positions[0]]
            for i in range(1, len(positions)):
                prev_pos = positions[i - 1]
                curr_pos = positions[i]
                gap_size = curr_pos - prev_pos - 1
                if gap_size <= 1:
                    gap_labels = out[prev_pos + 1:curr_pos]
                    if all(label == "O" for label in gap_labels):
                        current_group.append(curr_pos)
                    else:
                        groups.append(current_group)
                        current_group = [curr_pos]
                else:
                    groups.append(current_group)
                    current_group = [curr_pos]
            groups.append(current_group)

            # Rewrite each multi-span group as one continuous span.
            for group in groups:
                if len(group) > 1:
                    first_pos = group[0]
                    last_pos = group[-1]
                    for pos in range(first_pos + 1, last_pos + 1):
                        if pos in group[1:]:
                            out[pos] = f"I-{role}"
                        elif out[pos] == "O":
                            out[pos] = f"I-{role}"

        return out
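
    # Illustrative before/after for _apply_bio_rules (hand-constructed tags,
    # not real model output):
    #   tok = ["heavy", "rain", "caused", "flood", "##ing"]
    #   lab = ["I-C",   "I-C",  "O",      "B-E",   "O"]
    #   _apply_bio_rules(tok, lab)
    #   # -> ["B-C", "I-C", "O", "B-E", "I-E"]
    # The orphan I-C is promoted to B-C, and the O-tagged WordPiece "##ing"
    # is attached to the preceding effect span.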

    @staticmethod
    def _merge_spans(tok: List[str], lab: List[str], tokenizer) -> List["Span"]:
        """Turn cleaned BIO labels into Span objects.

        Policy:
            1. **First pass** – assemble raw spans, letting them bridge a single
               connector (of, to, with, for, and, or, but, in).
            2. **Trim** leading/trailing connectors & punctuation.
            3. **Normalise** hyphen spacing & strip quotes.
            4. **Role-wise pruning** – if a role has ≥1 multi-token span, drop
               its single-token spans unless they look contentful. This removes
               stray nouns like "choices" while preserving them when they are
               the *only* cause/effect.
        """
        CONNECT = {"of", "to", "with", "for", "and", "or", "but", "in"}

        # Pass 1: assemble raw spans, bridging single O-tagged connectors.
        spans: List[Span] = []
        i, n = 0, len(tok)
        while i < n:
            if lab[i] == "O":
                i += 1
                continue
            role = lab[i].split("-")[-1]
            s = i
            i += 1
            while i < n:
                if lab[i].startswith("I-"):
                    i += 1
                    continue
                if tok[i].lower() in CONNECT and lab[i] == "O" and i + 1 < n and lab[i + 1].startswith("I-"):
                    i += 1
                    continue
                break
            e = i - 1
            text = tokenizer.convert_tokens_to_string(tok[s:e + 1])

            # Normalise hyphen spacing, strip quotes, trim connectors.
            text = text.replace(" - ", "-").replace(" -", "-").replace("- ", "-")
            text = text.strip("\"'“”‘’")
            words = text.split()
            while words and words[0].lower() in CONNECT:
                words.pop(0)
            while words and words[-1].lower() in CONNECT:
                words.pop()
            if not words:
                continue
            clean_text = " ".join(words)
            spans.append(Span(role, s, e, clean_text))

        # Role-wise pruning of single-token spans.
        by_role = defaultdict(list)
        for sp in spans:
            by_role[sp.role].append(sp)
        final: List[Span] = []
        for role, group in by_role.items():
            has_multi = any((g.end_tok - g.start_tok) >= 1 for g in group)
            for sp in group:
                single_tok = (sp.end_tok - sp.start_tok) == 0
                if single_tok:
                    # Keep single-token spans only if they look contentful.
                    is_meaningful = (
                        len(sp.text) > 2 and
                        sp.text.isalpha() and
                        sp.text.lower() not in {"this", "that", "it", "they", "them", "he", "she", "we", "i", "you"}
                    )
                    if not is_meaningful and has_multi:
                        if role == "C" or role == "E":
                            continue
                final.append(sp)
        final.sort(key=lambda s: s.start_tok)

        # Merge same-role neighbours separated only by punctuation.
        merged: List[Span] = []

        def is_punct(t):
            return len(t) == 1 and not t.isalnum()

        for sp in final:
            if merged and sp.role == merged[-1].role:
                gap_tokens = tok[merged[-1].end_tok + 1:sp.start_tok]
                if gap_tokens and all(is_punct(t) for t in gap_tokens):
                    combined_text = tokenizer.convert_tokens_to_string(
                        tok[merged[-1].start_tok:sp.end_tok + 1]
                    ).strip("\"'“”‘’")
                    merged[-1] = Span(sp.role, merged[-1].start_tok, sp.end_tok, combined_text)
                    continue
            merged.append(sp)
        return merged
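
    # Illustrative sketch of _merge_spans on cleaned tags (hand-constructed;
    # a WordPiece tokenizer is assumed):
    #   tok = ["lack", "of", "sleep", "causes", "fatigue"]
    #   lab = ["B-C",  "O",  "I-C",   "O",      "B-E"]
    #   # Pass 1 bridges the connector "of", yielding:
    #   # -> [Span(role="C", start_tok=0, end_tok=2, text="lack of sleep"),
    #   #     Span(role="E", start_tok=4, end_tok=4, text="fatigue")]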

    def _decide_causal(self, cls_logits, spans, cause_decision):
        """Determine if a sentence is causal based on classification logits and spans.

        Args:
            cls_logits: Tensor of classification logits.
            spans: List of extracted spans.
            cause_decision: Strategy for determining causality ("cls_only",
                "span_only", or "cls+span").

        Returns:
            bool: True if the sentence is determined to be causal.
        """
        # Index 1 is assumed to be the "causal" class.
        prob_causal = torch.softmax(cls_logits, dim=-1)[1].item()

        has_cause_spans = any(x.role in ("C", "CE") for x in spans)
        has_effect_spans = any(x.role in ("E", "CE") for x in spans)
        has_both_spans = has_cause_spans and has_effect_spans

        if cause_decision == "cls_only":
            return prob_causal >= 0.5
        elif cause_decision == "span_only":
            return has_both_spans
        else:
            # Default "cls+span": both signals must agree.
            return prob_causal >= 0.5 and has_both_spans
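

# Hedged smoke test (assumes the local config/configuration modules are
# importable and the configured encoder can be downloaded; untrained weights
# will produce arbitrary predictions):
if __name__ == "__main__":
    cfg = JointCausalConfig()  # relies on the config's defaults
    model = JointCausalModel(cfg).eval()
    print(model.predict(["Heavy rain caused severe flooding."]))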