Passing Images into LLMs

Introduction

Most of my work with LLMs uses text inputs, but I sometimes use image inputs too. I continue to spend time learning the internals of the transformer architecture used in decoder-style LLMs, though most of my studying has been with text inputs. I was always a little curious, and confused, about how images are passed into LLMs; I just never took the time to dig into it. Now I am finally getting around to it. This blog post is a set of notes I'm taking as I learn more about this topic. I write about things so I can better understand them, and my future self is always grateful.

The main motivation for this post was Sebastian Raschka's recent blog post Understanding Multimodal LLMs. I highly recommend reading it. I'm going to focus on a subset of the topics that his blog post discusses, starting with what he refers to as Method A: Unified Embedding Decoder Architecture. I may go deeper on other topics in other blog posts, but for now this is a good start for me.

High Level Overview

Large Language Models (LLMs) have revolutionized the way we interact with text, enabling capabilities like natural language generation, reasoning, and conversation. However, their utility isn’t limited to text alone. Modern multimodal models can also understand and process images. How exactly are images passed into these models? That’s what this post aims to clarify.

This post starts with a recap of how transformer-based LLMs process text inputs. We then transition to how images can be converted into sequences of embeddings (just like tokens in text) and fed into LLMs. We’ll look at Vision Transformers (ViT), show how they encode images, and finally explain how these embeddings can be integrated into decoder-style LLMs for multimodal tasks such as image captioning and visual question answering.

Key Takeaways:

Recap of Transformer Architecture for Text Inputs

Decoder Style LLMs

We first need to have an understanding of the transformer architecture used in decoder style LLMs. Earlier this year I wrote my first blog post with some notes on the transformer architecture. To get the most out of this post, it would be good to have some familiarity with the transformer architecture. We will give a quick reminder of some basic concepts.

Most LLMs you interact with (like GPT-style models) are decoder-only transformers. In a decoder transformer:

We will load one of the SmolLM2 models created by the Hugging Face team. This is not the instruction fine-tuned model, but rather the base pre-trained model. It may not be as well known as some other models, but it is a good one to start with since it is really small and easy to run locally.

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
model
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(49152, 576)
    (layers): ModuleList(
      (0-29): 30 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=576, out_features=576, bias=False)
          (k_proj): Linear(in_features=576, out_features=192, bias=False)
          (v_proj): Linear(in_features=576, out_features=192, bias=False)
          (o_proj): Linear(in_features=576, out_features=576, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
          (up_proj): Linear(in_features=576, out_features=1536, bias=False)
          (down_proj): Linear(in_features=1536, out_features=576, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((576,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((576,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((576,), eps=1e-05)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=576, out_features=49152, bias=False)
)

The input to the transformer layers is a sequence of embeddings. In the case of text inputs, the input first gets converted into a sequence of tokens. Then each token is converted into an embedding vector.

Here is the conversion of the input text to token ids.

inputs = tokenizer(["The dog jumped over the"], return_tensors="pt")
input_ids = inputs.input_ids
print(inputs)
print(input_ids.shape)
print(input_ids)
{'input_ids': tensor([[  504,  2767, 25437,   690,   260]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]])}
torch.Size([1, 5])
tensor([[  504,  2767, 25437,   690,   260]])

Each token id has an associated embedding vector. In the case of this SmolLM2 model, the embedding dimension is 576 and there are 49152 tokens in the vocabulary.

embedding_lkp = model.model.embed_tokens
print(embedding_lkp.weight.shape)
torch.Size([49152, 576])

We can get the token embeddings by passing the token ids to the embedding lookup table. Each row of the returned tensor, ignoring the batch dimension, is a vector representation of a token.

embedding_vectors = embedding_lkp(input_ids)
print(embedding_vectors.shape)
print(embedding_vectors)
torch.Size([1, 5, 576])
tensor([[[ 0.1177,  0.0199, -0.0942,  ...,  0.0405,  0.1182,  0.0762],
         [-0.0356,  0.1338,  0.0050,  ...,  0.0996,  0.0791,  0.0791],
         [-0.0093,  0.0122,  0.0197,  ...,  0.0613, -0.1021, -0.0923],
         [-0.0339,  0.0825, -0.1562,  ...,  0.0349,  0.1172, -0.0752],
         [-0.1514,  0.0181, -0.0742,  ...,  0.0430,  0.0986,  0.0664]]],
       grad_fn=<EmbeddingBackward0>)
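To make the "lookup table" idea concrete, here is a minimal sketch showing that calling an embedding layer is the same as selecting rows from its weight matrix. The sizes and token ids are toy values I made up for illustration; they are not the SmolLM2 values.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy sizes (assumptions for illustration; SmolLM2-135M uses 49152 x 576).
vocab_size, embedding_dim = 10, 4
toy_embedding = nn.Embedding(vocab_size, embedding_dim)

# A batch with one sequence of three hypothetical token ids.
toy_input_ids = torch.tensor([[2, 7, 2]])

# Calling the layer is the same as selecting rows of its weight matrix.
looked_up = toy_embedding(toy_input_ids)
indexed = toy_embedding.weight[toy_input_ids]

print(looked_up.shape)                  # torch.Size([1, 3, 4])
print(torch.equal(looked_up, indexed))  # True
```

Note that the repeated token id (2) maps to the same vector both times: at this stage the embedding depends only on the token id, not on its position in the sequence.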

It is this sequence of embedding vectors that flows through the transformer layers. The input shape to the transformer layers is (batch_size, sequence_length, embedding_dim) and the output shape is (batch_size, sequence_length, hidden_size). You can get the last hidden state by passing the inputs to the model, excluding the final classification head.

last_hidden_state = model.model(**inputs).last_hidden_state
print(last_hidden_state.shape)
last_hidden_state
torch.Size([1, 5, 576])
tensor([[[ 0.3476,  0.7350,  0.1515,  ..., -0.0168,  0.8690,  1.1515],
         [ 0.0334,  0.6300,  0.7636,  ..., -0.6490,  0.0102, -0.2357],
         [-1.0193,  0.9439,  0.1579,  ..., -0.3536, -2.4959,  1.6141],
         [-2.0151, -0.3402, -0.6598,  ...,  1.7252, -1.6691,  1.4883],
         [-0.6080, -0.9785, -0.8922,  ...,  3.4061, -0.1228, -0.6294]]],
       grad_fn=<MulBackward0>)

Then this final transformer output is passed to the classification head. The classification head is a single linear layer that maps the hidden state to the logits for the next token. The output shape of the classification head is (batch_size, sequence_length, vocab_size).

logits = model.lm_head(last_hidden_state)
assert torch.allclose(logits, model(**inputs).logits)
logits.shape
torch.Size([1, 5, 49152])

Next we convert the logits to probabilities using the softmax function. While this is useful for visualization and inference, during training we typically use the raw logits directly with CrossEntropyLoss for better numerical stability. Note that we get logits (and after softmax, probabilities) for the next token at each position in the sequence. During inference, we typically only care about the last position's values since that's where we'll generate the next token.

probs = F.softmax(logits, dim=-1)
probs.shape
torch.Size([1, 5, 49152])
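For the training side mentioned above, here is a rough sketch of how raw logits are typically paired with next-token labels. It uses random toy tensors rather than the real model, and the variable names are my own. The point is that cross-entropy applies log-softmax internally, so we never need to compute probabilities explicitly.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy sizes (assumptions for illustration only).
batch_size, seq_len, vocab_size = 1, 5, 11
toy_logits = torch.randn(batch_size, seq_len, vocab_size)
toy_ids = torch.randint(0, vocab_size, (batch_size, seq_len))

# Position t predicts token t+1, so drop the last logit and the first label.
shift_logits = toy_logits[:, :-1, :].reshape(-1, vocab_size)
shift_labels = toy_ids[:, 1:].reshape(-1)

# cross_entropy takes raw logits and applies log-softmax internally.
loss = F.cross_entropy(shift_logits, shift_labels)

# The same value computed by hand from log-probabilities.
log_probs = F.log_softmax(shift_logits, dim=-1)
manual_loss = -log_probs[torch.arange(shift_labels.numel()), shift_labels].mean()

print(torch.allclose(loss, manual_loss))  # True
```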

This next code block shows that at inference time we get the probabilities for the next token at each position in the sequence. It prints the top 5 predictions for each token in the sequence.

K = 5  # Number of top predictions to show
top_probs, top_indices = torch.topk(probs[0], k=K, dim=-1)  # Remove batch dim and get top K

# Convert token indices to actual tokens and print predictions for each position
input_text = tokenizer.decode(input_ids[0])  # Original text
print(f"Original text: {input_text}\n")

for pos in range(len(input_ids[0])):
    token = tokenizer.decode(input_ids[0][pos])
    print(f"After token: '{token}'")
    print(f"Top {K} predicted next tokens:")
    for prob, idx in zip(top_probs[pos], top_indices[pos]):
        predicted_token = tokenizer.decode(idx)
        print(f"  {predicted_token}: {prob:.3f}")
    print()
Original text: The dog jumped over the

After token: 'The'
Top 5 predicted next tokens:
   first: 0.022
   same: 0.015
   most: 0.012
   world: 0.011
   last: 0.006

After token: ' dog'
Top 5 predicted next tokens:
   was: 0.063
   is: 0.062
  's: 0.047
  ’: 0.039
  ,: 0.031

After token: ' jumped'
Top 5 predicted next tokens:
   up: 0.200
   on: 0.135
   into: 0.068
   over: 0.063
   out: 0.062

After token: ' over'
Top 5 predicted next tokens:
   the: 0.793
   a: 0.032
   it: 0.030
   and: 0.017
   him: 0.013

After token: ' the'
Top 5 predicted next tokens:
   fence: 0.408
   wall: 0.029
   top: 0.017
   bridge: 0.017
   table: 0.013

In summary, the input to the transformer layers is a sequence of embeddings, of shape (batch_size, sequence_length, embedding_dim). The transformer layers process this sequence and return a new sequence of hidden states, of shape (batch_size, sequence_length, hidden_size). It is often the case that the hidden size is the same as the embedding dimension, but this is not a requirement. Even if you forget the details of the inner workings of the transformer layers (self attention, etc.), this is a useful mental model to keep in mind. The final classifier layer returns a probability distribution over the next token for each position in the sequence, of shape (batch_size, sequence_length, vocab_size).
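As a small illustration of why only the last position matters at inference time, here is a sketch of a single greedy decoding step. The logits are random stand-ins and the token ids are hypothetical; a real loop would run the model again on the extended sequence.

```python
import torch

torch.manual_seed(0)

# Toy stand-ins (assumptions for illustration): 5 input tokens, vocabulary of 11.
toy_input_ids = torch.tensor([[3, 1, 4, 1, 5]])
toy_logits = torch.randn(1, 5, 11)  # (batch_size, sequence_length, vocab_size)

# Only the last position is needed to pick the token that follows the input.
next_token_logits = toy_logits[:, -1, :]          # (1, 11)
next_token_id = next_token_logits.argmax(dim=-1)  # greedy choice, shape (1,)

# Append the chosen token; a real loop would feed this back into the model.
extended_ids = torch.cat([toy_input_ids, next_token_id[:, None]], dim=1)

print(extended_ids.shape)  # torch.Size([1, 6])
```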

Encoder Models

In contrast, encoder-only models like BERT process the entire input sequence at once without causal masking. They often use a special [CLS] token at the start of the sequence, whose final embedding serves as a global representation of the entire input for tasks like classification. Here are some of the key differences between decoder and encoder models:

Let's load a simple encoder model to illustrate some points.

from transformers import AutoModel

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModel.from_pretrained("distilbert-base-uncased")
model
DistilBertModel(
  (embeddings): Embeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer): Transformer(
    (layer): ModuleList(
      (0-5): 6 x TransformerBlock(
        (attention): DistilBertSdpaAttention(
          (dropout): Dropout(p=0.1, inplace=False)
          (q_lin): Linear(in_features=768, out_features=768, bias=True)
          (k_lin): Linear(in_features=768, out_features=768, bias=True)
          (v_lin): Linear(in_features=768, out_features=768, bias=True)
          (out_lin): Linear(in_features=768, out_features=768, bias=True)
        )
        (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (ffn): FFN(
          (dropout): Dropout(p=0.1, inplace=False)
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): Linear(in_features=3072, out_features=768, bias=True)
          (activation): GELUActivation()
        )
        (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      )
    )
  )
)
inputs = tokenizer(["The dog jumped over the"], return_tensors="pt")
input_ids = inputs.input_ids
print(input_ids)
print(tokenizer.decode(input_ids[0]))
last_hidden_state = model(**inputs).last_hidden_state
print(last_hidden_state.shape)
tensor([[ 101, 1996, 3899, 5598, 2058, 1996,  102]])
[CLS] the dog jumped over the [SEP]
torch.Size([1, 7, 768])

In the case of encoder models, the output shape is (batch_size, sequence_length, hidden_size). For the DistilBERT model loaded here (as for BERT base), the hidden size is 768, so a 768-dimensional vector is returned for each token in the input sequence, including the [CLS] and [SEP] tokens added by the tokenizer.

When using the encoder output for other tasks, such as classification, we typically take the [CLS] token embedding, which is the embedding for the first token.

last_hidden_state[:, 0, :].shape  # `[CLS]` token embedding
torch.Size([1, 768])
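Here is a minimal sketch of what "using the [CLS] embedding for classification" can look like. The encoder output is a random stand-in, and the two-label linear head is a hypothetical, untrained layer I added for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

hidden_size, num_labels = 768, 2  # num_labels is a hypothetical choice
# Stand-in for an encoder output of shape (batch_size, sequence_length, hidden_size).
toy_last_hidden_state = torch.randn(3, 7, hidden_size)

cls_embedding = toy_last_hidden_state[:, 0, :]   # first token of each sequence, (3, 768)
classifier = nn.Linear(hidden_size, num_labels)  # untrained classification head
class_logits = classifier(cls_embedding)         # (3, 2)

print(class_logits.shape)  # torch.Size([3, 2])
```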

I think it's worth elaborating on the importance of the [CLS] token. Why use the [CLS] token embedding as the final representation?

When processing text with BERT-style encoder models, each input sequence begins with a special token, in this case the [CLS] token. This token doesn’t represent a word or phrase from the input but acts as a placeholder for capturing information about the entire sequence. During training (for example, when fine-tuning on a classification task), the [CLS] token's output is what gets optimized for the sequence-level objective. In sentiment analysis, for instance, the model learns to encode the overall sentiment of the input sequence into the [CLS] token embedding. As a result, the [CLS] token becomes a rich summary representation of the entire input sequence. Self-attention allows the [CLS] token to attend to all other tokens in the sequence, which means it “sees” the entire context of the input. Recall that encoder models typically do not use causal masking. Through this process:

Using the [CLS] token provides a single, fixed-size vector (e.g., 768 dimensions for BERT) that can directly feed into a classifier or other downstream layers. These concepts are useful to keep in mind when we discuss image encoders later on.
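To see why the first token can summarize the whole sequence in an encoder but not in a decoder, here is a toy comparison of attention weights with and without a causal mask. The queries and keys are random and the sizes are made up; this is a sketch of the masking idea, not the attention code of any specific model.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

seq_len, dim = 4, 8  # toy sizes, assumptions for illustration
q = torch.randn(seq_len, dim)
k = torch.randn(seq_len, dim)
scores = q @ k.T / dim**0.5  # (seq_len, seq_len) attention scores

# Encoder style: every position may attend to every position.
encoder_weights = F.softmax(scores, dim=-1)

# Decoder style: position i may only attend to positions <= i.
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
decoder_weights = F.softmax(scores.masked_fill(~causal_mask, float("-inf")), dim=-1)

# Row 0 plays the role of the first token (where [CLS] sits).
print(encoder_weights[0])  # non-zero weight on all 4 positions
print(decoder_weights[0])  # all weight on position 0, zero elsewhere
```

With the causal mask, the first position can only look at itself, so a leading summary token would carry no information about the rest of the sequence.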

Introducing Images into Transformers and LLMs

Now that we've revisited how transformers handle text inputs, let's explore how images can be incorporated into transformers. Our eventual goal is to understand how to pass images into decoder-style LLMs, alongside text, to generate text outputs. If you remember that the input to the transformer layers is a sequence of embeddings, then passing in images is no different: we just need to convert the images into a sequence of embeddings suitable for the transformer layers. Text tokens are straightforward; they map directly from discrete tokens to embeddings via lookup tables. Images, on the other hand, must be transformed into a sequence of patch embeddings.

Key Insight:

By treating images as a sequence of flattened patches, we can feed them into a transformer architecture—just like we feed tokens into a text transformer.
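Here is a tiny sketch of what "a sequence of flattened patches" means, using a 4 x 4 single-channel toy image and 2 x 2 patches so the values are easy to follow. The sizes are my own choice for readability; real models use much larger images and patches.

```python
import torch

# A tiny single-channel 4 x 4 "image" (toy values so the patches are easy to read).
image = torch.arange(16).reshape(1, 1, 4, 4)  # (batch_size, channels, height, width)
patch_size = 2

b, c, h, w = image.shape
# Split height and width into (grid index, offset inside the patch).
patches = image.reshape(b, c, h // patch_size, patch_size, w // patch_size, patch_size)
# Bring the two grid indices forward, then flatten each patch into one vector.
patches = patches.permute(0, 2, 4, 1, 3, 5)
sequence = patches.reshape(b, -1, c * patch_size * patch_size)

print(sequence.shape)  # torch.Size([1, 4, 4])
print(sequence[0])
# rows: [0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15]
```

The result has shape (batch_size, num_patches, patch_dim), which is the same layout as a sequence of token embeddings.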

We’ll first look at how images are handled by Vision Transformers (ViT). Then we’ll explore models like CLIP, which bridge text and image embeddings, and finally see how these image embeddings are integrated into LLMs for multimodal tasks.

Vision Transformers (ViT)

The first architecture we will focus on is transformer-based image encoders. Specifically, we will examine the Vision Transformer (ViT), a model that adapts the transformer architecture from natural language processing to computer vision tasks. The ViT processes images by dividing them into fixed-size patches, embedding these patches as input tokens, and applying a transformer encoder to learn meaningful representations of the input image. Just like the transformer layers process a sequence of token embeddings, the ViT processes a sequence of image patch embeddings, and returns a sequence of hidden states.

Key Steps for ViT:

The First Vision Transformer

The first Vision Transformer was introduced in the paper AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE.

Figure 1 from the ViT paper: AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE

We can load such a pre-trained ViT model from Hugging Face.

from PIL import Image

image = Image.open("imgs/underwater.jpg")
image
from transformers import ViTImageProcessor, ViTModel

processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
model
ViTModel(
  (embeddings): ViTEmbeddings(
    (patch_embeddings): ViTPatchEmbeddings(
      (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): ViTEncoder(
    (layer): ModuleList(
      (0-11): 12 x ViTLayer(
        (attention): ViTSdpaAttention(
          (attention): ViTSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): ViTSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): ViTOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      )
    )
  )
  (layernorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (pooler): ViTPooler(
    (dense): Linear(in_features=768, out_features=768, bias=True)
    (activation): Tanh()
  )
)
inputs = processor(images=image, return_tensors="pt")
inputs.keys()
dict_keys(['pixel_values'])
inputs.pixel_values.shape  # (batch_size, num_channels, height, width)
torch.Size([1, 3, 224, 224])
# Create patch embeddings from image
patch_embeddings = model.embeddings.patch_embeddings(inputs.pixel_values)
patch_embeddings.shape  # [1, 196, 768]
torch.Size([1, 196, 768])
# Get complete input embeddings and pass through encoder manually.
# These embeddings are the patch embeddings plus the positional embeddings.
# As well as the `[CLS]` token embedding.
full_input_embeddings = model.embeddings(inputs.pixel_values)
encoder_outputs = model.encoder(full_input_embeddings)
manual_output = model.layernorm(encoder_outputs.last_hidden_state)

# Get output using full model forward pass
with torch.no_grad():
    model_outputs = model(**inputs)
    full_model_output = model_outputs.last_hidden_state

# Verify shapes match
assert manual_output.shape == full_model_output.shape

# Verify outputs are identical
assert torch.allclose(manual_output, full_model_output, atol=1e-6)
model_outputs.keys()
odict_keys(['last_hidden_state', 'pooler_output'])
model_outputs.last_hidden_state.shape  # (batch_size, sequence_length, hidden_size)
torch.Size([1, 197, 768])
model_outputs.last_hidden_state[:, 0, :].shape  # `[CLS]` token embedding
torch.Size([1, 768])
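The jump from 196 patch embeddings to a sequence length of 197 comes from the extra [CLS] token. Here is a rough sketch of that step with randomly initialized stand-ins for the learned [CLS] vector and position embeddings. The variable names are my own, and this is a simplified view rather than the exact Hugging Face implementation.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

image_size, patch_size, hidden = 224, 16, 768
num_patches = (image_size // patch_size) ** 2  # 14 * 14 = 196

# Stand-ins for learned parameters (randomly initialized here).
cls_token = nn.Parameter(torch.randn(1, 1, hidden))
position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, hidden))

patch_embeddings = torch.randn(2, num_patches, hidden)  # toy batch of 2 images

# Prepend one [CLS] vector per image, then add position information.
cls = cls_token.expand(patch_embeddings.shape[0], -1, -1)
tokens = torch.cat([cls, patch_embeddings], dim=1) + position_embeddings

print(num_patches)   # 196
print(tokens.shape)  # torch.Size([2, 197, 768])
```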

In the case of the ViT encoder model, followed by a classification task/head, it is this [CLS] token embedding that we will use as the image representation for downstream tasks. Just like in text encoders, the [CLS] token in ViT learns to aggregate information from all the image patches through self-attention. During training, this token's representation is optimized to capture the global features needed for image classification. However, it's important to note that this [CLS] token approach is specific to encoder-based vision transformers used for classification tasks. When we later discuss feeding images into decoder-style LLMs for tasks like image captioning or visual question answering, we'll see a different approach where the sequence of patch embeddings itself is used directly, without needing a [CLS] token.

We've now seen how Vision Transformers process images in a way that's analogous to how text transformers process words. The image is divided into patches (like words in a sentence), each patch is flattened from a 16x16x3 grid of pixels into a 768-dimensional vector (16 x 16 x 3 = 768, which happens to equal this model's hidden size), then transformed through a learned linear projection layer to create patch embeddings (like word embeddings). In the model printed above, this flatten-and-project step is implemented as a single Conv2d layer whose kernel size and stride both equal the patch size. Positional embeddings are added to maintain spatial information (like position encodings in text). The key insight is that both text and image transformers fundamentally operate on sequences of embeddings - the main difference is just in how we create these embeddings from the raw input. For ViT, it's through patch extraction, flattening, and linear projection; for text, it's through token lookup tables.
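To check the claim that "flatten each patch, then apply a linear projection" and a strided convolution are the same operation, here is a sketch that computes patch embeddings both ways with the same weights. The layer is freshly initialized and the image is random, so this only demonstrates the equivalence, not the pre-trained model's values.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

patch_size, channels, hidden = 16, 3, 768
conv = nn.Conv2d(channels, hidden, kernel_size=patch_size, stride=patch_size)
pixel_values = torch.randn(1, channels, 224, 224)

# Path 1: convolution, then flatten the 14 x 14 grid into a sequence.
conv_out = conv(pixel_values)                      # (1, 768, 14, 14)
conv_tokens = conv_out.flatten(2).transpose(1, 2)  # (1, 196, 768)

# Path 2: flatten each patch, then apply the same weights as a linear map.
patches = pixel_values.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(1, -1, channels * patch_size**2)
weight = conv.weight.reshape(hidden, -1)           # (768, 3 * 16 * 16)
linear_tokens = patches @ weight.T + conv.bias

print(conv_tokens.shape)                                      # torch.Size([1, 196, 768])
print(torch.allclose(conv_tokens, linear_tokens, atol=1e-4))  # True
```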

CLIP

CLIP (Contrastive Language-Image Pre-training) represents a significant milestone in connecting visual and textual understanding. Unlike the original ViT which focused solely on image classification, CLIP learns to understand the relationship between images and their natural language descriptions. CLIP was created by OpenAI.

CLIP's architecture consists of two encoders working in parallel:

A text encoder (transformer) that:

An image encoder (can be a ViT or a CNN such as ResNet, but let's focus on ViT) that:

The key innovation is the contrastive learning process:

This aligned semantic space enables powerful capabilities:

Figure 1 from the paper: "Learning Transferable Visual Models From Natural Language Supervision"
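The contrastive objective can be sketched in a few lines. This is a simplified version with random stand-ins for the encoder outputs and a fixed temperature (CLIP learns this value during training); the variable names are mine, not from the CLIP codebase.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

batch_size, embed_dim = 4, 32  # toy sizes, assumptions for illustration
# Stand-ins for the L2-normalized outputs of the image and text encoders.
image_features = F.normalize(torch.randn(batch_size, embed_dim), dim=-1)
text_features = F.normalize(torch.randn(batch_size, embed_dim), dim=-1)

temperature = 0.07  # assumed fixed here; CLIP learns this value
logits_per_image = image_features @ text_features.T / temperature  # (4, 4) scaled cosine similarities
logits_per_text = logits_per_image.T

# The i-th image matches the i-th caption, so the targets are the diagonal.
targets = torch.arange(batch_size)
loss = (F.cross_entropy(logits_per_image, targets)
        + F.cross_entropy(logits_per_text, targets)) / 2

print(logits_per_image.shape)  # torch.Size([4, 4])
print(loss.item() > 0)         # True
```

Minimizing this loss pushes matching image and text pairs (the diagonal) toward high similarity and all other pairs in the batch toward low similarity.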

from PIL import Image
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

image = Image.open("imgs/tropical_island.jpg")
image