Knowledge Distillation at a Low Level

Hoyath
5 min read · Jan 24, 2025


We’ve all heard about knowledge distillation and how it helps make models smaller while sacrificing only a little performance. Essentially, it involves distilling knowledge from a larger parent model into a comparatively smaller student model. Let’s delve into how this knowledge transfer actually works.

Understanding Knowledge Distillation

In knowledge distillation, the goal is to transfer the “knowledge” from a parent model (the teacher) to a smaller model (the student). Classification models, including large language models (LLMs), typically output a probability distribution over classes after applying the softmax function.

Example of Softmax Output

Suppose we have a neural network model that outputs class probabilities for three classes after applying softmax. Consider the following logits from the parent model:

  • Parent Output Logits: [1.1, 0.2, 0.2]

After applying softmax, these logits become:

  • Softmax Output: [0.552, 0.224, 0.224]

Here, class 0 has the highest probability, making it the predicted class. However, the model also assigns lower probabilities to classes 1 and 2. This indicates that while class 0 is the most likely, there are features in the input data that suggest the input might also belong to classes 1 or 2 with lower confidence.
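To make this concrete, here is a minimal Python sketch that reproduces the numbers above. The `softmax` helper is just an illustrative implementation, not code from any particular library:

```python
import math

def softmax(logits):
    """Convert raw logits into a probability distribution."""
    exps = [math.exp(z) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

parent_logits = [1.1, 0.2, 0.2]
print(softmax(parent_logits))  # ~[0.552, 0.224, 0.224]
```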

Leveraging Lower Probability Information

Typically, we ignore the lower probability classes since the highest probability (0.552) is significantly larger than the others (0.224 each). However, the idea behind knowledge distillation is to utilize this additional information to train the student model.

Horse and Deer Analogy:

Imagine the parent model is tasked with identifying animals and it encounters an image of a horse. The model outputs the highest probability for “horse” but also assigns some probability to “deer” and “cow.” This happens because horses and deer share common features: both have four legs and a tail. However, distinguishing features like the horse’s larger size and distinct head shape make “horse” the more probable class. By assigning these smaller probabilities, the model acknowledges that while the input is most likely a horse, it shares some aspects with other classes like deer and cow.

Consider another scenario where the parent model’s logits are:

  • Parent Output Logits: [2.9, 0.1, 0.23]

After applying softmax:

  • Softmax Output: [0.885, 0.054, 0.061]

Here, class 0 dominates with a probability of 0.885, but the other classes still retain some information. To capture this more nuanced information, we can soften the distribution by dividing the logits by a temperature T = 3 before applying softmax. The temperature-scaled logits become:

  • Logits / T: [0.967, 0.033, 0.077]

After applying softmax to the temperature-scaled logits:

  • Softmax with Temperature: [0.554, 0.218, 0.228]

This softened distribution retains information about the dominant class while amplifying the probabilities of the other classes. These are referred to as soft probabilities, which contain richer information compared to the hard labels like [1, 0, 0].
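Here is a small sketch of temperature scaling, assuming a hypothetical `softmax_with_temperature` helper that simply divides the logits by T before the softmax:

```python
import math

def softmax_with_temperature(logits, T=1.0):
    """Divide logits by T before softmax; larger T gives a softer distribution."""
    scaled = [z / T for z in logits]
    exps = [math.exp(z) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

parent_logits = [2.9, 0.1, 0.23]
print(softmax_with_temperature(parent_logits, T=1))  # ~[0.885, 0.054, 0.061]
print(softmax_with_temperature(parent_logits, T=3))  # ~[0.554, 0.218, 0.228]
```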

Training the Student Model

When training a smaller student model, using only the hard labels (e.g., [1, 0, 0]) means the model is solely focused on predicting the correct class. The loss function typically used here is the cross-entropy loss. However, with knowledge distillation, we also incorporate the soft probabilities from the parent model to provide additional information.
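As a quick illustration, the hard-label cross-entropy for a single example reduces to the negative log of the probability the student assigns to the true class. The `cross_entropy_hard` helper and the example numbers below are purely illustrative:

```python
import math

def cross_entropy_hard(student_probs, true_class):
    """Cross-entropy with a one-hot (hard) label: only the probability
    assigned to the true class contributes to the loss."""
    return -math.log(student_probs[true_class])

# Suppose the true class is 0 and the student predicts [0.26, 0.32, 0.42].
print(cross_entropy_hard([0.26, 0.32, 0.42], 0))  # ~1.347
```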

Combined Loss Function

The total loss for the student model consists of two parts:

  1. Hard Loss: This is the standard cross-entropy loss between the student’s predictions and the true labels.
  2. Soft Loss: This is the loss calculated using the parent model’s soft probabilities.

Mathematically, the combined loss is typically expressed as a weighted sum of the two terms:

Total Loss = α · Hard Loss + (1 − α) · T² · Soft Loss

where α balances the cross-entropy on the true labels against the distillation (soft) term, and T is the temperature used to soften both distributions.

Calculating KL Divergence

To quantify the difference between the parent’s soft probabilities and the student’s predictions, we use the Kullback-Leibler (KL) divergence:

D_KL(p ‖ q) = Σ_i p_i · log(p_i / q_i)

Where:

  • p_i are the parent model’s soft probabilities.
  • q_i are the student model’s predicted probabilities.

Example Calculation

Let’s calculate the KL divergence between the teacher and student predictions:

  • Teacher Soft Probabilities: [0.554, 0.218, 0.228]
  • Student Soft Probabilities: [0.26, 0.32, 0.42]

Calculating each term:

  • 0.554 · log(0.554 / 0.26) ≈ 0.419
  • 0.218 · log(0.218 / 0.32) ≈ −0.084
  • 0.228 · log(0.228 / 0.42) ≈ −0.139

Summing these up:

D_KL ≈ 0.419 − 0.084 − 0.139 ≈ 0.196
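The same number can be reproduced with a few lines of Python; the `kl_divergence` helper below is an illustrative implementation using the natural logarithm:

```python
import math

def kl_divergence(p, q):
    """D_KL(p || q) = sum_i p_i * log(p_i / q_i), using the natural log."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

teacher = [0.554, 0.218, 0.228]
student = [0.26, 0.32, 0.42]
print(kl_divergence(teacher, student))  # ~0.196
```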

Final Loss Calculation

To account for the temperature scaling, we adjust the KL divergence by multiplying it by T²:

Soft Loss = T² · D_KL ≈ 9 × 0.196 ≈ 1.77 (with T = 3)

This adjustment ensures that the magnitude of the KL divergence does not diminish too much, which could otherwise lead to tiny gradient steps during backpropagation. By incorporating both the hard loss and the scaled KL divergence, the student model benefits from the additional information provided by the teacher model, facilitating more efficient learning.
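Putting the pieces together, here is a minimal sketch of the full distillation loss for a single example. The function name `distillation_loss`, the weighting factor `alpha`, and the example student logits are illustrative assumptions, not a reference implementation:

```python
import math

def softmax(logits, T=1.0):
    """Softmax with optional temperature: divide logits by T first."""
    exps = [math.exp(z / T) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def distillation_loss(student_logits, teacher_logits, true_class, T=3.0, alpha=0.5):
    """Weighted sum of the hard cross-entropy loss and the T^2-scaled KL term."""
    # Hard loss: cross-entropy between the student's T=1 prediction and the true label.
    student_probs = softmax(student_logits)
    hard_loss = -math.log(student_probs[true_class])

    # Soft loss: KL divergence between the softened teacher and student
    # distributions, scaled by T^2 so its gradients stay comparable in size.
    p = softmax(teacher_logits, T)
    q = softmax(student_logits, T)
    soft_loss = (T ** 2) * sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

    return alpha * hard_loss + (1 - alpha) * soft_loss

# Example: teacher logits from earlier, hypothetical student logits, true class 0.
print(distillation_loss([0.1, 0.3, 0.8], [2.9, 0.1, 0.23], 0))
```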

Finally

Using only hard labels (e.g., [1, 0, 0]) gives the student model nothing but the correct answer to learn from, which makes training harder. By incorporating the additional information in the teacher model’s soft probabilities, the student model can learn more effectively and efficiently. This makes knowledge distillation a powerful technique for creating smaller yet highly capable models.

I hope this explanation clarifies how knowledge distillation works at a low level. If you have any questions or need further clarification, feel free to ask!
