Methodik: Das PRISMS Framework – Modulare Fusion für Robuste Segmentierung
1. Grundlagen der PRISMS-Architekturbausteine
In diesem Abschnitt werden die theoretischen Grundlagen und Implementierungsdetails der einzelnen Module erläutert, die das PRISMS-Framework bilden. PRISMS ist darauf ausgelegt, multimodale Daten robust zu fusionieren und präzise Segmentierungen zu ermöglichen, auch bei unvollständigen Datensätzen.
1.1. Modality Encoder
Konzept und Mathematische Beschreibung: Der Modality Encoder dient der Extraktion tiefgehender, hierarchischer Merkmale für jede einzelne Eingangsmodalität \( m \). Ausgehend von einem Rohdateneingang \( x^m_0 \) (z.B. ein 3D-MRT-Volumen) wird dieser durch eine Sequenz von Faltungsblöcken \( E_m^{(l)} \) auf verschiedenen Ebenen \( l \) transformiert, um eine Repräsentation \( x^m_L \) auf der tiefsten Ebene sowie Skip-Connection-Features \( x^m_l \) für höhere Auflösungen zu generieren:
\[ x^m_{l+1} = E_m^{(l+1)}(x^m_l) \]Jeder Block \( E_m^{(l)} \) besteht typischerweise aus Faltungsoperationen, Normalisierung und Aktivierungsfunktionen, oft ergänzt durch Residualverbindungen. In unserer Implementierung werden general_conv3d
-Blöcke mit pad_type='reflect'
verwendet. Die strikt separate Verarbeitung der Modalitäten in den Encodern ist fundamental, um die Robustheit gegenüber fehlenden Modalitäten zu gewährleisten.
Codeausschnitt (Illustrativ für einen Encoder-Block):
class Encoder(nn.Module):
def __init__(self):
super(Encoder, self).__init__()
self.e1_c1 = general_conv3d(1, basic_dims, pad_type='reflect')
self.e1_c2 = general_conv3d(basic_dims, basic_dims, pad_type='reflect')
self.e1_c3 = general_conv3d(basic_dims, basic_dims, pad_type='reflect')
# ... weitere Blöcke ...
def forward(self, x):
x1 = self.e1_c1(x)
x1 = x1 + self.e1_c3(self.e1_c2(x1))
# ... Verarbeitung ...
return x1, x2, x3, x4, x5
Dieser Code demonstriert die initialen Faltungsschichten und eine Residualverknüpfung.
1.2. Adaptive Fusion Transformer (AFT) im PRISM Bottleneck
Konzept und Mathematische Beschreibung: Der Adaptive Fusion Transformer (AFT) ist das zentrale Modul im sogenannten PRISM Bottleneck. Er integriert initial Informationen aus allen verfügbaren Modalitäten und lernbaren Fusions-Tokens. Die tiefsten Merkmals-Tensoren \( x^m_L \) aus den Encodern werden für jede Modalität \( m \) zu Token-Sequenzen \( F_m \in \mathbb{R}^{B \times N \times C'} \) entfaltet. Zusätzlich werden lernbare Fusions-Tokens \( F_{\mathrm{fusion}} \in \mathbb{R}^{B \times N_f \times C'} \) verwendet.
Diese Token-Mengen werden konkateniert und mit lernbaren Positions-Embeddings \( P \) additiv kombiniert, um den Eingabe-Token-Satz \( Z_0 \) für den AFT zu bilden:
\[ Z_0 = \mathrm{Concat}\left( F_1, \dots, F_M, F_{\mathrm{fusion}} \right) + P \]Dieser Satz \( Z_0 \) wird dann durch \( L_1 \) Transformer-Blöcke des AFT verarbeitet. Jeder Block wendet Mechanismen wie die Modalitätsmaskierte Attention (MMA) (siehe Abschnitt 1.5) an, um globale Abhängigkeiten unter Berücksichtigung einer Eingangsmaske \(\mathcal{M}\) zu lernen:
\[ Z_{l+1} = \mathrm{AFTBlock}_l(Z_l, \mathcal{M}) \]Der resultierende Tensor \( Z_{L_1} \) wird in transformierte Modalitäts-Features \( F'_m \), transformierte Fusions-Tokens \( F'_{\mathrm{fusion}} \) und Attention-Matrizen (insbesondere \(\text{Attn}_1\) für die SRA) aufgeteilt.
Codeausschnitt (Illustrativ für die AFT-Verarbeitung):
embed_cat = torch.cat((embed_flair, embed_t1ce, embed_t1, embed_t2, fusion), dim=1)
embed_cat = embed_cat + pos
# self.trans_bottle ist der MaskedTransformer (AFT-Block)
embed_cat_trans, attn_matrices = self.trans_bottle(embed_cat, mask)
flair_trans, t1ce_trans, t1_trans, t2_trans, fusion_trans = \
torch.chunk(embed_cat_trans, num_modals + 1, dim=1)
Der Code illustriert die Token-Zusammenführung, Positions-Embedding-Addition, Transformer-Verarbeitung (AFT) und die Aufteilung der resultierenden Tokens.
1.3. Spatial Weight Attention (SRA)
Konzept und Mathematische Beschreibung: Die Spatial Relevnce Attention (SRA) quantifiziert und appliziert die räumliche Wichtigkeit von Merkmalsregionen. Sie nutzt die im ersten Layer des AFT berechnete Attention-Matrix \(\text{Attn}_1\). Für jeden ursprünglichen Modalitäts-Token \( j \) wird dessen räumliche Relevanz als Summe seiner Attention-Gewichte zu allen Fusions-Tokens \( i \) berechnet:
\[ \text{Relevance}(j) = \sum_{H_a} \sum_{i \in \text{FusionTokens}} \text{Attn}_1(i, j) \]Diese Relevanzwerte werden für jede Modalität \( m \) zu einer räumlichen Wichtigkeitskarte \( I_m \in \mathbb{R}^{B \times 1 \times H \times W \times D} \) umgeformt.
Diese Karten \( I_m \) werden dann elementweise mit den transformierten Modalitäts-Features \( F'_m \) (Ausgabe des AFT) und den ursprünglichen Encoder-Skip-Connection-Features \( x^m_l \) auf verschiedenen Ebenen multipliziert (nach Upsampling der \( I_m \)-Karten):
\[ \widetilde{F}'_m = F'_m \odot I_m \]\[ \widetilde{x}^m_l = x^m_l \odot \mathrm{Upsample}(I_m, \text{scale}_l) \]Fehlende Modalitäten führen zu einer Null-Karte \( I_m \). Die resultierenden gewichteten Features \(\widetilde{F}'_m\) und \(\widetilde{x}^m_l\) werden weiterverarbeitet.
Codeausschnitt (Konzeptuell für die SRA-Gewichtung):
# attn_matrices[0] ist Attn_1 aus dem ersten AFT-Layer
# ... (Details der Indexierung und Summierung zur Erzeugung von spatial_importance_map_m) ...
# Beispielhafte Anwendung:
weighted_feature_map_mod_m = feature_map_mod_m * spatial_importance_map_m
Die SRA gewichtet Features aus dem AFT und Skip-Connections, um räumlich relevante Informationen hervorzuheben.
1.4. Cross-Modal Fusion Transformer (KFT)
Konzept und Mathematische Beschreibung:
Der Kanalbezogener Fusion Transformer (KFT), im Code durch MultiCrossToken
-Blöcke (self.CTx
) realisiert, ist für die fortgeschrittene Fusion und Re-Kalibrierung der modalitätsspezifischen Informationen auf verschiedenen Ebenen des Decoders zuständig. Er ermöglicht eine tiefe Interaktion zwischen den Token-Strömen der verschiedenen Modalitäten und den transformierten Fusions-Tokens.
Auf jeder Decoder-Ebene \( d \) nimmt ein KFT-Block \( \text{KFT}_d \) die Fusions-Features \( \mathcal{F}_{\text{fusion}}^{(d-1)} \) von der vorherigen Ebene und die SRA-gewichteten Skip-Connection-Features der einzelnen Modalitäten \( \widetilde{x}^m_l \) als Eingabe. Der KFT-Block besteht aus \( L_2 \) Schichten, die typischerweise Cross-Attention- und Self-Attention-Mechanismen unter Verwendung der MMA (siehe Abschnitt 1.5) beinhalten:
- Cross-Attention: Fusions-Features \( \mathcal{F}_{\text{fusion}}^{(d-1)} \) als Queries, konkatenierte modalitätsspezifische Features \( \mathrm{Concat}(\{\widetilde{x}^m_l\}_M) \) als Keys/Values.
- Self-Attention: Zur Verfeinerung der aggregierten Fusions-Features.
Die Ausgabe ist ein Satz verfeinerter Fusions-Features \( \mathcal{F}_{\text{fusion}}^{(d)} \) für die aktuelle Decoder-Ebene:
\[ \mathcal{F}_{\text{fusion}}^{(d)}, \text{UpdatedModalFeatures}^{(d)} = \mathrm{KFT}_d(\mathcal{F}_{\text{fusion}}^{(d-1)}, \{\widetilde{x}^m_l\}_M, \mathcal{M}) \]Codeausschnitt (Illustrativ für einen MultiCrossToken
-Block als KFT):
class MultiCrossToken(nn.Module): # Repräsentiert einen KFT-Block
# ... init mit MultiMaskCrossBlock Layers ...
def forward(self, inputs_modal_features, kernel_fusion_features, mask):
# inputs_modal_features: SRA-gewichtete Skip-Features
# kernel_fusion_features: upgesampelte Fusions-Features von tieferer Ebene
# MultiMaskCrossBlock führt MMA-basierte Cross-Attention und FFN durch
# ... (Loop über Layers) ...
return kernels # Verfeinerte Fusions-Features
Der KFT (MultiCrossToken
) verarbeitet modalitätsspezifische und Fusions-Features unter Berücksichtigung der Maske, um die Fusions-Features auf jeder Decoder-Ebene zu aktualisieren.
1.5. Modalitätsmaskierte Attention (MMA) (Modality Masked Attention)
Konzept und Mathematische Beschreibung: Die Modalitätsmaskierte Attention (MMA) ist ein fundamentaler Mechanismus in den Transformer-Blöcken von PRISMS (AFT und KFT). Sie stellt sicher, dass Attention-Berechnungen nur zwischen Tokens von tatsächlich vorhandenen Modalitäten oder zwischen Modalitäts-Tokens und Fusions-Tokens stattfinden. Die Standard-Attention ist:
\[ \text{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V \]Die MMA modifiziert die Attention-Scores \( A = \frac{QK^T}{\sqrt{d_k}} \) vor der Softmax-Operation. Eine binäre Attention-Maske \(\mathcal{A_M}\), basierend auf der Eingangsmaske \(\mathcal{M}\), wird generiert. Scores für ungültige Interaktionen (\(\mathcal{A_M}(i,j) = 0\)) werden auf \(-\infty\) gesetzt:
\[ A'_{ij} = \begin{cases} A_{ij} & \text{if } \mathcal{A_M}(i,j) = 1 \\ -\infty & \text{if } \mathcal{A_M}(i,j) = 0 \end{cases} \]\[ \text{MMA}(Q, K, V) = \mathrm{softmax}(A')V \]Dies verhindert, dass Tokens fehlender Modalitäten die Repräsentationen beeinflussen.
Codeausschnitt (Illustrativ für MMA im AFT/MaskedAttention
):
attn_scores = (q @ k.transpose(-2, -1)) * self.scale
# self_mask_for_attn wird basierend auf der Eingangsmaske 'mask' generiert
# z.B. durch mask_gen_fusion für AFT oder mask_gen_cross4 für KFT
self_mask_for_attn = mask_generator_function(B, ..., mask)
attn_scores = attn_scores.masked_fill(self_mask_for_attn == 0, float("-inf"))
attn_probs = attn_scores.softmax(dim=-1)
# ... attn_probs @ v ...
Der Code zeigt die Modifikation der Attention-Scores mittels masked_fill
zur Realisierung der MMA.
2. Der PRISM-Distillationsmechanismus in PRISMS
Das PRISMS-Framework integriert den PRISM-Distillationsansatz, um Wissen von einem starken “Teacher” (der multimodalen Fusionsrepräsentation) auf “Student” (effektiv unimodale Verarbeitungspfade innerhalb des Fusionsnetzwerks) zu übertragen.
2.1. Pixel-weiser KL Divergence Loss
Konzept: Angleichung der Vorhersageverteilungen auf Pixelebene. Der Teacherpfad (Haupt-Fusionspipeline von PRISMS) erzeugt Logits \( z^t_n \). Für jeden Studentpfad \( m \) (konzeptionell erzeugt durch Verarbeitung nur der Modalität \( m \) durch die Fusionspipeline mit unimodaler Maske) werden Logits \( z^m_n \) erzeugt. Der Verlust ist:
\[ L^{m}_{\text{pixel-KL}} = \sum_n D_{\mathrm{KL}}\!\left[\sigma(z^t_n/\tau)\,\middle\|\,\sigma(z^m_n/\tau)\right] \]Dies wird für finale Vorhersagen und Aux Heads berechnet.
2.2. Prototype Alignment Loss (Semantische Distillation)
Konzept: Angleichung der semantischen Klassenrepräsentationen. Für jede Klasse \( k \) werden Prototyp-Vektoren durch räumliche Aggregation der Merkmale des Teachers (\(f^t_n\)) und des Students \( m \) (\(f^m_n\)) berechnet:
\[ S^{t}_{k} = \mathrm{Pool}_{\text{class } k}(f^t_n), \quad S^{m}_{k} = \mathrm{Pool}_{\text{class } k}(f^m_n) \]Der Verlust minimiert die \(L_2\)-Distanz zwischen den Prototypen:
\[ L^{m}_{\text{align}} = \sum_{k} \bigl\|S^t_{k} - S^m_{k}\bigr\|_2^2 \]Rationale und Implementierung in PRISMS: Die Gradienten aus dem kombinierten PRISM-Verlust \(L_{\text{prism}}\) aktualisieren die Gewichte der gemeinsam genutzten Komponenten (AFT, SRA, KFT, Haupt-Fusionsdecoder). Mechanismen zur “Preference-Aware Re-Weighting” können integriert werden.
Codeausschnitt (Prototype Alignment Loss):
align_loss_flair, dist_flair = prototype_alignment_loss(
student_features=de_f_flair[0], # Merkmale des Flair-Students
teacher_features=de_f_avg[0].detach(),# Merkmale des Teachers (kein Gradient)
target_labels=target,
student_logits=fuse_pred_flair, # Logits des Flair-Students
teacher_logits=fuse_pred.detach(), # Logits des Teachers (kein Gradient)
# ...
)
# L_prism = Summe aller align_loss_m und kl_loss_m
Der Code zeigt die Berechnung des Prototype Alignment Loss für einen Studentpfad.
3. Zusammenfassung der PRISMS-Komponenten und des PRISM-Mechanismus
- Modality Encoder: Extrahiert unabhängige, hierarchische Merkmale pro Modalität.
- Adaptive Fusion Transformer (AFT) im PRISM Bottleneck: Initiale globale Fusion von Modalitäts- und Fusions-Tokens mittels MMA.
- Spatial Weight Attention (SRA): Gewichtung der räumlichen Relevanz von AFT-Ausgaben und Skip-Connections basierend auf AFT-Attention.
- Cross-Modal Fusion Transformer (KFT): Tiefe Interaktion und Re-Kalibrierung von Fusions- und Skip-Connection-Features auf Decoder-Ebenen mittels MMA (realisiert durch
MultiCrossToken
-Blöcke). - Modalitätsmaskierte Attention (MMA): Kernmechanismus in AFT und KFT zur Berücksichtigung nur valider Token-Paare.
- Haupt-Fusionsdecoder: Verarbeitet KFT-Ausgaben und SRA-gewichtete Skips; dient als “Teacher”.
- PRISM-Distillationsmechanismus: Überträgt Wissen vom Teacher auf konzeptionelle Studentpfade mittels Pixel-KL- und Prototype-Alignment-Loss.
- Regulation Decoders: Separate Decoder pro Modalität für \(L_{\text{reg}}\).