Overview: Core Pipeline · Distillation: Loss Strategies · Supported Tasks · Training Paradigms · Workbench (Post-Hoc Analysis)

SFT (Fine-Tuning)

Standard supervised learning with CE loss
alpha: 0.0 · FinetuningTrainer
Simplest baseline

KD (Knowledge Distillation)

Large teacher → small student
alpha > 0.0 · DistillationTrainer
9 pluggable loss strategies
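The alpha blend between hard-label CE and the distillation term can be sketched in pure Python, assuming the common Hinton-style formulation with temperature-scaled forward KL and a T² factor (the exact form inside DistillationTrainer may differ):

```python
import math

def softmax(logits, T=1.0):
    # numerically stable softmax at temperature T
    m = max(logits)
    exps = [math.exp((z - m) / T) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def kd_loss(student_logits, teacher_logits, label, alpha=0.5, T=2.0):
    # hard-label cross-entropy on the student
    ce = -math.log(softmax(student_logits)[label])
    # forward KL(teacher || student) at temperature T, scaled by T^2
    p_t = softmax(teacher_logits, T)
    p_s = softmax(student_logits, T)
    kl = sum(t * math.log(t / s) for t, s in zip(p_t, p_s))
    # alpha = 0.0 recovers plain SFT; alpha > 0.0 mixes in distillation
    return (1 - alpha) * ce + alpha * (T * T) * kl
```

With alpha = 0.0 this collapses to the SFT objective above, which is why SFT and KD can share one trainer configuration.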

RL (Reinforcement Learning)

Post-training via GRPO optimisation
rl: config block · trl.GRPOTrainer
Optimise toward rewards

Classification

MNLI · QQP · SST-2 · CoLA
Encoder-only: BERT → DistilBERT
Metrics: Accuracy · F1 · MCC

Translation

Europarl EN-FR
Decoder-only: Llama 3.2 3B → 1B
Prompt masking · Metric: BLEU

OCR / Doc AI

olmOCR · OmniDocBench · Synthmix
Vision-Language: Qwen3-VL 3B → 2B
Metrics: Edit similarity · Exact match

Research Goal

FAR Factorial Study: systematically explore all Feature + Attention + Response loss combinations.
→ Which distillation strategy works best for each task?

Planned extensions: Diffusion · Voice · Code Gen · Summarisation

YAML Config

Hierarchical inheritance; all hyperparameters live in config:
task + model + mode · lr · batch · temperature · alpha · warmup...
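A hypothetical flattened config illustrating the kind of keys involved (the key names here are assumptions for illustration, not Forge's actual schema):

```yaml
# illustrative only — key names are assumptions, not Forge's actual schema
task: classification
model:
  student: distilbert-base-uncased
  teacher: bert-base-uncased
mode: kd
training:
  lr: 5.0e-5
  batch_size: 32
  warmup_ratio: 0.1
distillation:
  alpha: 0.5
  temperature: 2.0
```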

Data Loading

GLUE Handler · Europarl Handler · OCR Handler
Tokenise or process · Prompt masking · Batch collation
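Prompt masking for decoder-only training typically replaces prompt-token labels with an ignore index so the loss is computed only on the response; a minimal sketch, using the -100 convention from PyTorch's CrossEntropyLoss default ignore_index:

```python
IGNORE_INDEX = -100  # PyTorch CrossEntropyLoss default ignore_index

def mask_prompt(input_ids, prompt_len):
    """Build causal-LM labels: prompt tokens are masked out so the
    loss is computed only on the response tokens."""
    labels = list(input_ids)
    labels[:prompt_len] = [IGNORE_INDEX] * prompt_len
    return labels
```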

Evaluation

Direct eval loop
(no nested Trainer)

Pre + post metrics
Transfer gap analysis:
· teacher advantage
· student gain
· transfer efficiency
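One plausible reading of the three transfer-gap quantities, sketched in pure Python (the eval loop's exact formulas are assumptions here):

```python
def transfer_gap(teacher_acc, student_pre_acc, student_post_acc):
    # hypothetical definitions — the eval loop's exact formulas may differ
    teacher_advantage = teacher_acc - student_pre_acc    # headroom the teacher offers
    student_gain = student_post_acc - student_pre_acc    # what training recovered
    transfer_efficiency = (student_gain / teacher_advantage
                           if teacher_advantage else float("nan"))
    return teacher_advantage, student_gain, transfer_efficiency
```

Under this reading, transfer efficiency is the fraction of the teacher's headroom that distillation actually delivered to the student.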

Outputs

outputs/.../timestamp/

· config.yaml
· run_metadata.json
· output.log
· checkpoint-N/
· merged final model

W&B + SLURM

All metrics to W&B
Loss component breakdown

SLURM auto-requeue
Graceful SIGUSR1 stop
Atomic checkpoints
Auto-resume from latest

Response-Based
(output logits only)

response_kl
response_reverse-kl
response_mse
response_cosine
response_ce
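Forward and reverse KL differ in which distribution's mass dominates the penalty; a pure-Python sketch over a single token's distribution (illustrative, not the trainer's implementation):

```python
import math

def softmax(logits):
    m = max(logits)
    e = [math.exp(z - m) for z in logits]
    s = sum(e)
    return [x / s for x in e]

def kl(p, q):
    # KL(p || q): forward KL is kl(teacher, student) — mass-covering;
    # reverse KL is kl(student, teacher) — mode-seeking
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

teacher = softmax([3.0, 1.0, 0.0])
student = softmax([2.0, 2.0, 0.0])
forward_kl = kl(teacher, student)   # response_kl
reverse_kl = kl(student, teacher)   # response_reverse-kl
```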

Feature-Based
(hidden states)

feature_mse
(FitNets)

feature_cosine
(DistilBERT-style)

Attention-Based
attention_kl

Composite

combined_*
8 FAR combos:
F · A · R · FR
FA · AR · FAR · CE

staged_*
Progressive distillation
over training steps
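A staged schedule might ramp from feature/attention supervision early in training toward response matching late; a hypothetical linear schedule (the actual staged_* schedules are not specified here):

```python
def staged_weights(step, total_steps):
    # hypothetical linear ramp — emphasise feature/attention losses early,
    # response loss late; Forge's real staged_* schedules may differ
    frac = min(step / total_steps, 1.0)
    return {"feature": 1.0 - frac, "attention": 1.0 - frac, "response": frac}
```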

Forge

Universal, config-driven framework for training deep learning models.

One command: python forge.py --config path/to/config.yaml

Every experiment = a YAML config change, not a code change.

Model Setup

Load student + teacher
LoRA adapter config
Frozen or co-train
Auto embed resize
Attention implementation

Training Loop

CE + α·Distillation
or GRPO rewards

Callbacks:
· SLURM signal handler
· checkpoint marker
· peak metrics tracker

Tech Stack

PyTorch 2.6+
HF Transformers 5.2+
PEFT 0.18+ (LoRA)
TRL 0.29+ (GRPO)
Accelerate · Datasets
Evaluate · SacreBLEU
Pillow · olmocr
W&B · SLURM · conda

Layer Mapping

uniform — evenly spaced across teacher layers
last_n — last N teacher layers
first_last — first and last layers only
custom — explicit user mapping
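The mapping strategies amount to index selection over teacher layers; a sketch assuming 0-based layer indices (the custom option takes an explicit list):

```python
def map_layers(strategy, n_student, n_teacher, custom=None):
    """Return, for each supervised student layer, the teacher layer index."""
    if strategy == "uniform":      # evenly spaced across the teacher
        return [round(i * (n_teacher - 1) / (n_student - 1))
                for i in range(n_student)]
    if strategy == "last_n":       # the last n_student teacher layers
        return list(range(n_teacher - n_student, n_teacher))
    if strategy == "first_last":   # only the endpoints are supervised
        return [0, n_teacher - 1]
    if strategy == "custom":       # explicit user-provided mapping
        return list(custom)
    raise ValueError(strategy)
```

For a 6-layer student against a 12-layer teacher (the DistilBERT/BERT shape above), uniform picks [0, 2, 4, 7, 9, 11].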

Projections
identity
linear · mlp

Linear Probing

Train linear probes on
frozen hidden states
at each layer

Where does task info
emerge in the model?
Teacher vs student
layer-by-layer

CKA Analysis

Centered Kernel
Alignment

Which student layers
align with which
teacher layers?

Validates layer mapping
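Linear CKA between two layers' activation matrices (rows = examples, columns = features) can be sketched in pure Python; a sketch for small matrices, whereas production code would use batched matrix ops:

```python
import math

def _center(M):
    # subtract each column (feature) mean
    means = [sum(col) / len(M) for col in zip(*M)]
    return [[v - m for v, m in zip(row, means)] for row in M]

def _tmm(A, B):
    # A^T @ B for row-major nested lists
    return [[sum(A[k][i] * B[k][j] for k in range(len(A)))
             for j in range(len(B[0]))] for i in range(len(A[0]))]

def _fro2(M):
    # squared Frobenius norm
    return sum(v * v for row in M for v in row)

def linear_cka(X, Y):
    # CKA(X, Y) = ||Y^T X||_F^2 / (||X^T X||_F * ||Y^T Y||_F)
    X, Y = _center(X), _center(Y)
    return _fro2(_tmm(Y, X)) / (math.sqrt(_fro2(_tmm(X, X)))
                                * math.sqrt(_fro2(_tmm(Y, Y))))
```

CKA is invariant to isotropic scaling and orthogonal rotation of either representation, which is why it is a reasonable check on whether a chosen student layer actually tracks its mapped teacher layer.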

Disagreement

Where does student
deviate from teacher?

Head Fine-Tuning

Diagnose:
capacity issue vs
alignment issue

Config Hierarchy

base.yaml
  └─ classification/cola/base.yaml
       └─ distilbert/base.yaml
            └─ bert-base/base.yaml
                 └─ ablations/01_kd_frozen-base.yaml

Each level overrides parent.
Final config = merged chain.
One YAML per experiment.
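The merge chain behaves like a recursive dict override; a minimal sketch, assuming child keys win and nested dicts merge recursively:

```python
def deep_merge(parent, child):
    # child values override parent; nested dicts merge recursively,
    # so an ablation YAML only states what it changes
    merged = dict(parent)
    for key, value in child.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged
```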

Teacher Strategies

  1. Frozen SFT teacher
  2. Frozen Base teacher
  3. Co-train SFT teacher
  4. Co-train Base teacher

Best overall:
Co-train Base

Risk: a frozen SFT teacher
can be overconfident on
imbalanced data