PRISMS D Loss

    graph TD
    %% Style definitions
    classDef losses fill:#f88,stroke:#333,stroke-width:2px,color:#fff;
    classDef regulation fill:#ffb380,stroke:#333,stroke-width:2px,color:#333;
    classDef prism fill:#9f9,stroke:#333,stroke-width:2px,color:#333;
    classDef outputs fill:#ffcc99,stroke:#333,stroke-width:2px,color:#333;
    
    %% PART D: LOSS COMPUTATION
    
    %% Model Outputs
    subgraph OUTPUTS["Network Outputs"]
        MAIN_SEG["Main Segmentation<br/>(Teacher Path)"]:::outputs
        AUX_SEG["Auxiliary Outputs<br/>(PRM1-PRM5)"]:::outputs
        STUDENT_SEG["Student Segmentations<br/>(Flair, T1ce, T1, T2)"]:::outputs
        STUDENT_AUX["Student Auxiliary Outputs<br/>(Student PRMs)"]:::outputs
    end
    
    %% Progressive Weighting
    subgraph WEIGHTING["Progressive Weighting Mechanism"]
        WEIGHTS["Weight Initialization<br/>(weight_prm = 1.0)"]:::regulation
        
        WEIGHT1["PRM1 Weight = 1.0"]:::regulation
        WEIGHT2["PRM2 Weight = 0.5"]:::regulation
        WEIGHT3["PRM3 Weight = 0.25"]:::regulation
        WEIGHT4["PRM4 Weight = 0.125"]:::regulation
        WEIGHT5["PRM5 Weight = 0.0625"]:::regulation
    end
    
    %% Loss Components 
    subgraph LOSS_COMPONENTS["Loss Components"]
        %% Main Segmentation Loss
        MAIN_LOSS["Main Segmentation Loss<br/><br/>• Dice Loss<br/>• Cross-Entropy Loss"]:::losses
        
        %% Auxiliary (PRM) Loss
        AUX_LOSS["Auxiliary (PRM) Loss<br/><br/>• Weighted by scale<br/>• Applied to teacher & students"]:::losses
        
        %% PRISM-specific Losses
        KL_LOSS["KL Divergence Loss<br/><br/>• Pixel-wise distillation<br/>• Applied at each scale<br/>• Temperature-scaled"]:::prism
        
        PROTO_LOSS["Prototype Loss<br/><br/>• Feature-level distillation<br/>• Class-wise feature alignment<br/>• Cosine similarity matching"]:::prism
    end
    
    %% Combined Loss
    FINAL_LOSS["Final Combined Loss<br/><br/>L = L_seg + L_aux + L_kl + L_proto"]:::losses
    
    %% Code Example
    CODE_EXAMPLE["PROGRESSIVE WEIGHT CODE SNIPPET:<br/><br/>weight_prm = 1.0<br/>for prm_pred, prm_pred_flair, up_op in zip(preds, preds_flair, self.up_ops):<br/>    weight_prm /= 2.0<br/>    kl_loss += masks_mod0 * weight_prm * temp_kl_loss_bs(<br/>        prm_pred_flair,    # Student PRM<br/>        prm_pred.detach(), # Teacher PRM (detached)<br/>        target, num_cls=num_cls, temp=temp, up_op=up_op<br/>    )"]
    style CODE_EXAMPLE fill:#eee,stroke:#333,stroke-width:1px,color:#333,text-align:left,font-family:monospace
    
    %% Connect Components
    MAIN_SEG --> MAIN_LOSS
    AUX_SEG --> AUX_LOSS
    
    WEIGHTS --> WEIGHT1
    WEIGHTS --> WEIGHT2
    WEIGHTS --> WEIGHT3
    WEIGHTS --> WEIGHT4
    WEIGHTS --> WEIGHT5
    
    WEIGHT1 --> AUX_LOSS
    WEIGHT2 --> AUX_LOSS
    WEIGHT3 --> AUX_LOSS
    WEIGHT4 --> AUX_LOSS
    WEIGHT5 --> AUX_LOSS
    
    STUDENT_SEG --> KL_LOSS
    MAIN_SEG --> KL_LOSS
    
    STUDENT_AUX --> KL_LOSS
    AUX_SEG --> KL_LOSS
    
    STUDENT_SEG --> PROTO_LOSS
    MAIN_SEG --> PROTO_LOSS
    
    MAIN_LOSS --> FINAL_LOSS
    AUX_LOSS --> FINAL_LOSS
    KL_LOSS --> FINAL_LOSS
    PROTO_LOSS --> FINAL_LOSS
    
    %% Technical Details
    TECH_DETAILS["LOSS CALCULATION DETAILS:<br/><br/>• Segmentation loss: softmax_weighted_loss_bs + dice_loss_bs<br/>• KL Distillation: temp_kl_loss_bs (temperature-scaled)<br/>• Prototype loss: prototype_prism_loss_bs (feature consistency)<br/>• Progressive weights halve at each decoder level (1.0→0.0625)<br/>• Student losses only applied to visible modalities (via masking)"]
    style TECH_DETAILS fill:#eee,stroke:#333,stroke-width:1px,color:#333