How does GPT-4o work?

This is based purely on speculation; I have no inside information, and all numbers were pulled out of my ass. It’s just my highly detailed, educated guess at how the model works, along with a PyTorch implementation. If that sounds interesting, keep reading.

Tokenization

The cornerstone of GPT-4o’s capabilities lies in its unified representation of diverse input types. Here’s how each modality is tokenized:

Text Tokenization

If you’re familiar with LLMs, you can skip this section. GPT-4o employs SentencePiece with BPE (Byte Pair Encoding) for text tokenization. This combination provides a robust, language-independent subword tokenization method that’s particularly effective for handling multiple languages and out-of-vocabulary words. Let’s delve deeper into how this works:

Tokenization Using SentencePiece

Suppose we want to train a SentencePiece model on the following corpus:

lower lower lower

Step 1: Initial Vocabulary

SentencePiece starts with an initial vocabulary consisting of all individual characters from the input text, including a special symbol to denote whitespace. For example:

Initial vocabulary: {'l', 'o', 'w', 'e', 'r', '▁'}

Step 2: Iterative Pair Merging (Byte Pair Encoding)

The model then identifies the most frequent contiguous pairs of characters in the text and merges them into subword units. This process is repeated iteratively until the desired vocabulary size is reached. Let’s go through the process in more detail:

  1. First Iteration:

    • Input text: lower lower lower
    • The most frequent pair of characters: l and o
    • Merge result: lo
    • New vocabulary: {'l', 'o', 'lo', 'w', 'e', 'r', '▁'}
    • Updated text: lo w e r lo w e r lo w e r
  2. Second Iteration:

    • Input text: lo w e r lo w e r lo w e r
    • The most frequent pair of characters: lo and w
    • Merge result: low
    • New vocabulary: {'l', 'o', 'lo', 'low', 'w', 'e', 'r', '▁'}
    • Updated text: low e r low e r low e r

The process continues iteratively until the predefined vocabulary size is reached. The final vocabulary might include common subwords or entire words, depending on their frequency in the training data.
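
To make the merge loop concrete, here is a tiny, self-contained sketch of generic BPE training on the toy corpus above. It is purely illustrative and has nothing to do with OpenAI’s actual tokenizer or vocabulary.

from collections import Counter

# Toy BPE training on the corpus above; '▁' marks word boundaries as in SentencePiece.
corpus = "lower lower lower"
tokens = [list("▁" + w if i else w) for i, w in enumerate(corpus.split())]

def most_frequent_pair(words):
    pairs = Counter()
    for word in words:
        pairs.update(zip(word, word[1:]))
    return pairs.most_common(1)[0][0] if pairs else None

def merge_word(word, pair, merged):
    out, i = [], 0
    while i < len(word):
        if i + 1 < len(word) and (word[i], word[i + 1]) == pair:
            out.append(merged)  # replace the pair with the new subword
            i += 2
        else:
            out.append(word[i])
            i += 1
    return out

for step in range(3):  # the target vocabulary size controls how many merges run
    pair = most_frequent_pair(tokens)
    if pair is None:
        break
    merged = "".join(pair)
    tokens = [merge_word(word, pair, merged) for word in tokens]
    print(step + 1, pair, tokens)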

Audio Tokenization

Audio tokenization in GPT-4o is particularly interesting. The process appears to be inspired by recent advances in neural audio compression, most notably the EnCodec model.

  • Input Format: Raw waveform audio at 24 kHz sampling rate.
  • Preprocessing:
    1. Resampling to 24 kHz if necessary.
    2. Normalization to the range [-1, 1].
    3. Splitting into 250 ms frames (6000 samples per frame).
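
Here is a rough sketch of those three preprocessing steps, using the sample rate and frame size assumed above (none of this is a confirmed GPT-4o detail):

import torch

sample_rate = 24_000
frame_len = sample_rate // 4                       # 250 ms -> 6000 samples per frame

waveform = torch.randn(1, 2 * sample_rate)         # stand-in for 2 seconds of mono audio
# 1. Resample to 24 kHz here if needed (e.g. torchaudio.functional.resample).
# 2. Normalize to [-1, 1].
waveform = waveform / waveform.abs().max().clamp(min=1e-8)
# 3. Split into 250 ms frames, dropping any trailing partial frame.
num_frames = waveform.shape[-1] // frame_len
frames = waveform[..., : num_frames * frame_len].reshape(num_frames, 1, frame_len)
print(frames.shape)                                # torch.Size([8, 1, 6000])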

Encoder Architecture

The audio encoder in GPT-4o consists of:

  1. An initial 1D convolution layer with 32 channels and a kernel size of 7.
  2. Four convolutional blocks, each comprising:
    • A residual unit with two convolutions (kernel size 3) and a skip-connection
    • A down-sampling layer (strided convolution) with strides (2, 4, 5, 8)
    • The number of channels doubles after each down-sampling operation
  3. A sequence modeling component (two-layer LSTM) for capturing temporal dependencies.
  4. A final convolutional layer with a kernel size of 7 to produce the latent representation.

The encoder uses ELU (Exponential Linear Unit) as the activation function and employs Weight Normalization for streamable processing.

Quantization

The encoder output is quantized using Residual Vector Quantization (RVQ), a powerful technique that allows for efficient compression of high-dimensional vectors. In the context of GPT-4o’s audio processing:

  1. Each input frame vector is quantized using a first codebook of size 65,536.
  2. The residual (difference between the input and the quantized output) is then quantized using the next codebook.
  3. This process is repeated for a predefined number of codebooks, in this case, 4.
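
Here is a minimal sketch of that residual quantization loop. The latent dimension and the small, untrained codebooks are placeholders; a real system would learn the codebooks and use the 4 × 65,536 configuration described above.

import torch
import torch.nn as nn

latent_dim, num_codebooks, codebook_size = 256, 4, 1024   # small stand-in codebooks
codebooks = nn.ModuleList([nn.Embedding(codebook_size, latent_dim) for _ in range(num_codebooks)])

def rvq_encode(z):                                 # z: (batch, frames, latent_dim)
    indices, quantized, residual = [], torch.zeros_like(z), z
    for codebook in codebooks:
        # squared distance from the current residual to every codebook entry
        distances = (residual.unsqueeze(-2) - codebook.weight).pow(2).sum(dim=-1)
        idx = distances.argmin(dim=-1)             # nearest entry per frame
        quant = codebook(idx)
        indices.append(idx)
        quantized = quantized + quant
        residual = residual - quant                # the next codebook quantizes what is left
    return torch.stack(indices, dim=-1), quantized

tokens, approx = rvq_encode(torch.randn(2, 8, latent_dim))
print(tokens.shape, approx.shape)                  # (2, 8, 4) indices, (2, 8, 256) reconstruction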

Token Structure

The resulting audio tokens in GPT-4o have a unique structure:

  • Dimension: 4 x 65536 (4 codebooks x codebook size)
  • These are flattened into a single vector before entering the embedding layer.

The number and size of the codebooks were chosen so that, once flattened, the audio vocabulary is roughly the same size as the text and image vocabularies, while still providing enough fidelity for good-quality speech.
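
One plausible way to flatten the four indices for the embedding layer is to offset each codebook’s index into a single shared table and sum the four resulting embeddings per audio frame; the embedding width below is just a placeholder.

import torch
import torch.nn as nn

num_codebooks, codebook_size, d_model = 4, 65_536, 64     # d_model is a stand-in width
audio_embedding = nn.Embedding(num_codebooks * codebook_size, d_model)

frame_tokens = torch.randint(0, codebook_size, (1, 10, num_codebooks))  # 10 audio frames
offsets = torch.arange(num_codebooks) * codebook_size     # give each codebook its own index range
frame_emb = audio_embedding(frame_tokens + offsets).sum(dim=2)
print(frame_emb.shape)                                    # torch.Size([1, 10, 64])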

Image Tokenization

The image tokenization process in GPT-4o is a crucial component that bridges the gap between raw images and the discrete tokens that the transformer model can process. This process is similar to the approach described in the Chameleon paper but includes some modifications and improvements.

  1. Tokenizer Architecture: The image tokenizer is based on a Vector-Quantized Variational Autoencoder (VQ-VAE) architecture. This tokenizer is trained to encode a 768 × 768 image into 16,384 discrete tokens, each selected from a codebook of size 65,536. This is a significant increase from the 8,192 codebook size used in Chameleon, allowing for more detailed image representations, which is particularly important for OCR, something Chameleon struggles with.

  2. Encoding Process:

    1. The input image is passed through an encoder network, which transforms it into a 3D tensor of latent vectors.
    2. Each spatial location in this latent space is then quantized to its nearest neighbor in the learned codebook.
    3. The indices of these nearest neighbors become the discrete tokens representing the image.
  3. Vector Quantization: The core of the tokenization process is the vector quantization step. Here’s how it works:

    • For each latent vector z_e(x) produced by the encoder:
      1. Find the closest embedding e_i in the codebook.
      2. The index i becomes the token for this spatial location.
  4. Integration with Transformer: These image tokens are then embedded into the same dimensional space as the text and audio tokens. This allows the transformer to process all modalities in a unified manner.

  5. Decoding: During generation, the transformer predicts image tokens, which are then passed through the VQ-VAE’s decoder to reconstruct the final image.
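
As a quick illustration of the quantization step, here is the nearest-neighbour lookup for a toy latent grid. The 32-dimensional latents and 128 × 128 grid match the numbers above, but the small, untrained codebook is only a stand-in for the real 65,536-entry one.

import torch
import torch.nn as nn

embedding_dim, codebook_size = 32, 1024        # stand-in codebook; the real one would be 65,536
codebook = nn.Embedding(codebook_size, embedding_dim)

z_e = torch.randn(1, 128, 128, embedding_dim)  # encoder output: a 128x128 grid of latent vectors
flat = z_e.view(-1, embedding_dim)             # 16,384 latent vectors
# squared distance from every latent vector to every codebook entry
d = (flat.pow(2).sum(1, keepdim=True)
     + codebook.weight.pow(2).sum(1)
     - 2 * flat @ codebook.weight.t())
tokens = d.argmin(dim=1).view(1, 128, 128)     # one discrete token per spatial location
print(tokens.shape, tokens.numel())            # torch.Size([1, 128, 128]) 16384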

Source Flag

A binary flag is added to the end of tokens, indicating whether the token was produced by the model or received as input. This is implemented as an additional 0 or 1 at the end of the one-hot encoding for tokens. This crucial feature enables the model to distinguish between its outputs and real-time inputs during processing.

Positional Encodings

This is the real secret sauce of GPT-4o: the ability to handle all of these modalities in real time. This is particularly challenging because the different types of input arrive at different rates:

  • Audio frames: 4 per second
  • Images: Dependent on frame rate
  • Text: Varies based on user input speed

To address this, GPT-4o employs a dual encoding system similar to σ-GPTs:

  1. Modality-specific position encoding: Represents the position within a specific modality
  2. Time-based encoding: A function of absolute time

These encodings are concatenated to the end of the embedding rather than added, preserving their individual information and preventing them from interfering with one another. This goes against typical transformer architecture design, but it is vital for allowing the temporal and modality-specific position embeddings to operate separately. During training and inference, the attention mask considers both the position within a modality and the temporal position, ensuring proper alignment of multimodal inputs.

The time-based embeddings are at the same frequency as the audio, meaning 4 steps per second. This temporal resolution also applies to the image input stream, limiting the model to a maximum of 4 frames per second for video input. However, you can somewhat work around this by giving multiple frames the same time encoding and varying the modality-specific encoding. In practice, you probably wouldn’t want to do this, as you’d start to be limited by compute and would gain minimal information from the extra frames.
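
Here is a tiny sketch of the “concatenate rather than add” idea with made-up dimensions; only the structure matters.

import torch
import torch.nn as nn

d_model, d_pos, d_time = 512, 64, 64                        # widths are placeholders
pos_table = nn.Embedding(8192, d_pos)                       # position within a modality
time_table = nn.Embedding(4096, d_time)                     # absolute time, 4 steps per second

tokens = torch.randn(1, 10, d_model)                        # ten already-embedded tokens
modality_pos = torch.arange(10).unsqueeze(0)                # 0..9 within this stream
time_steps = torch.tensor([[0, 0, 0, 1, 1, 2, 2, 3, 3, 3]]) # which 250 ms step each token arrived in

x = torch.cat([tokens, pos_table(modality_pos), time_table(time_steps)], dim=-1)
print(x.shape)                                              # (1, 10, d_model + d_pos + d_time)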

Model Inputs

The model has five input streams, which are all concatenated together with attention masks to enforce temporal boundaries:

  1. Text input stream
  2. Audio input stream
  3. Image input stream from the real-time video feed
  4. Text/image output stream
  5. Audio output stream

Note that these are not hard boundaries, and tokens from different modalities can be mixed into the same sequence. For example:

  • “Describe what is happening in the following image: [image]”
  • “Transcribe the following audio: [audio]”

These streams are more akin to logical separators for the real-time inputs. The last two streams are for the model’s outputs, one for text/image and one for audio. This is where the source flag from earlier comes into play, allowing the model to distinguish its own outputs from the inputs. It needs this because it cannot treat the real-time inputs autoregressively: additional tokens keep being added to the sequence while it is generating an output.
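
One way to picture these temporal boundaries is a mask built from each token’s time step: a token may attend to anything that had already arrived (or been generated) by its own time step. This is a guess at the mechanism, not a known detail.

import torch

def temporal_mask(times: torch.Tensor) -> torch.Tensor:
    # times: one time step (4 per second) per token in the combined sequence.
    # Entry [i, j] is True if token i may attend to token j, i.e. j is not from the future.
    return times.unsqueeze(-1) >= times.unsqueeze(-2)

# Toy sequence: three text tokens at t=0, four audio frames at t=0..3, one video frame at t=2
times = torch.tensor([0, 0, 0, 0, 1, 2, 3, 2])
print(temporal_mask(times).int())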

Model Outputs

GPT-4o features a dual-headed output system:

  1. Text and image output head:

    • Predicts text and image tokens using standard one-hot encoding
    • Includes a “No Output” token
  2. Audio output head:

    • Multiple sub-heads for each codebook in the audio tokenization
    • Also includes a “No Output” token

The “No Output” tokens are not passed back into the input and allow the model to:

  • Choose which modality to reply with
  • Stay silent without consuming extra tokens
    • Respond with text/image and audio simultaneously

This design gives GPT-4o the flexibility to generate appropriate responses across modalities.
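
To make the dual-headed decoding concrete, here is a sketch of a single decoding step. The head modules, the placeholder sizes, and the convention that the last index means “No Output” are assumptions for illustration only.

import torch
import torch.nn as nn

d_model, text_image_vocab, audio_vocab, num_codebooks = 512, 1000, 1000, 4  # placeholder sizes
text_image_head = nn.Linear(d_model, text_image_vocab + 1)                  # +1 for "No Output"
audio_heads = nn.ModuleList([nn.Linear(d_model, audio_vocab + 1) for _ in range(num_codebooks)])

hidden = torch.randn(1, d_model)                    # final hidden state for this step
text_image_id = text_image_head(hidden).argmax(-1).item()
audio_ids = [head(hidden).argmax(-1).item() for head in audio_heads]

if text_image_id != text_image_vocab:               # last index = "No Output"
    print("emit text/image token", text_image_id)
if any(i != audio_vocab for i in audio_ids):
    print("emit audio token stack", audio_ids)
# If both heads pick "No Output", the model stays silent and nothing is fed back into the input.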

Model Size

230 billion parameters. I made this number up, but here’s a plausible justification:

  • Llama 3.1 70B costs around $0.5 per million input tokens.
  • GPT-4o-2024-08-06 costs $2.50 per million input tokens.
  • Assume ClosedAI has pricing power due to the model’s intelligence and brand, so they can charge around 50% more.

Calculation: 70B × (2.50 / 0.50) / 1.5 ≈ 233B, or roughly 230 billion parameters

This size is also approximately equivalent to one GPT-4 expert, which was rumored to be 8x222B. This could explain why the cost is so much lower compared to GPT-4.

This is all just speculation, and I have no affiliation with ClosedAI.

Training

The training process for GPT-4o is complex and multi-staged, incorporating several advanced techniques to achieve its impressive multimodal and real-time capabilities.

Pretraining

During this stage, the temporal embeddings are set to the same value for all text and image samples and are equal to the position embedding for the audio samples. The source flag is also set to 0 for all tokens. Note that the audio and image tokenizers must be trained separately before beginning.

  1. Interleaved Text and Image Pretraining: Following an approach similar to the Chameleon model, GPT-4o is first pretrained on a large dataset of interleaved text and image data scraped from the web. This allows the model to develop a unified representation across these modalities.

  2. Audio Tokenizer Training: After the initial pretraining, the audio embeddings are trained while the main model weights are frozen. The model performs autoregressive prediction on audio tokens, learning to understand and generate audio content. Data for this is sourced by scraping online podcasts and applying voice activity detection to filter the audio.

  3. Multimodal Unfrozen Training: Once the audio tokenizer is sufficiently trained, all weights are unfrozen, and the model trains on all modalities simultaneously, further integrating its understanding across text, image, and audio. Note that audio samples still haven’t been combined with the text and images at this stage; that happens later. The output head that isn’t used for a sample is trained to predict the “No Output” token.

Fine-tuning and Optimization

  1. Instruction Fine-tuning: The model undergoes instruction fine-tuning to improve its ability to follow user instructions and perform specific tasks. The source flag is now set to 1 for tokens produced by the model.

  2. Direct Preference Optimization (DPO): Utilizing data from user sessions on the ChatGPT website, GPT-4o employs Direct Preference Optimization to align its outputs with user preferences. This method allows for efficient optimization without the need for complex reinforcement learning algorithms, which are difficult to train and prone to over-fitting. The DPO approach used in GPT-4o enables the model to learn from human preferences by directly optimizing a policy to satisfy these preferences, using a simple binary cross-entropy objective (a minimal sketch of this loss follows the list below).

  3. Chain-of-Thought Training: GPT-4o employs a technique similar to the stepwise internalization method described in recent research. This process helps the model internalize intermediate reasoning steps:

    • The model is initially trained to generate explicit chain-of-thought reasoning steps.
    • Gradually, these intermediate steps are removed during training, forcing the model to internalize the reasoning process.
    • The training progresses through multiple stages, each removing more of the explicit reasoning steps.
    • By the final stage, the model can perform implicit chain-of-thought reasoning, producing high-quality outputs without generating explicit intermediate steps.

    This technique allows GPT-4o to reason effectively while maintaining the speed advantages of models that don’t use explicit chain-of-thought.
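
For the DPO step above, here is the standard DPO objective in PyTorch form: a binary cross-entropy over the log-probability ratios of the chosen and rejected responses under the policy and a frozen reference model. The β value is illustrative.

import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    # Each argument is the summed log-probability of a full response under the
    # policy or the frozen reference model.
    chosen_ratio = policy_chosen_logps - ref_chosen_logps
    rejected_ratio = policy_rejected_logps - ref_rejected_logps
    # Binary cross-entropy on the margin: prefer the chosen response over the rejected one.
    return -F.logsigmoid(beta * (chosen_ratio - rejected_ratio)).mean()

loss = dpo_loss(torch.tensor([-12.3]), torch.tensor([-15.1]),
                torch.tensor([-13.0]), torch.tensor([-14.2]))
print(loss)

And a sketch of the stepwise-internalization idea from the chain-of-thought item: each training stage simply drops more of the explicit reasoning tokens from the target before the usual next-token loss is applied. The schedule and removal order are made up for illustration.

def drop_cot_tokens(reasoning_tokens, stage, tokens_per_stage=8):
    # Later stages hide more of the chain of thought (here: from the front),
    # until the model must go straight from the prompt to the answer.
    removed = min(len(reasoning_tokens), stage * tokens_per_stage)
    return reasoning_tokens[removed:]

print(drop_cot_tokens(list(range(20)), stage=0))   # full reasoning kept
print(drop_cot_tokens(list(range(20)), stage=2))   # 16 reasoning tokens removed
print(drop_cot_tokens(list(range(20)), stage=5))   # nothing left: implicit reasoning only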

Real-time Multimodal Integration

Until now, all of the training has followed the standard autoregressive setup. The final stage of training splits the input into the five streams described above and focuses on integrating the various modalities in real time:

  1. Synthetic Multimodal Data Generation: The base model generates scripts that simulate real-time interactions, including dialogue between users and the assistant, along with actions like “Begin transcribing what I’m saying” or “Describe what is happening in these images”. This leverages the model’s text and image generation capabilities to create scripts similar to what you might see in a TV show, along with supplemental images. Portions of the scripts are then converted to audio using an existing text-to-speech model.

  2. Real-time Processing Simulation: The synthetic data is processed to simulate real-time inputs, with text, audio, and image inputs interleaved to mimic real-world scenarios. A system prompt is added with audio tokens from the target speaker to select a voice for the model.

  3. Final Integrated Training: GPT-4o undergoes a final round of training on this synthetic real-time multimodal data, enhancing its ability to seamlessly integrate and respond to text, image, and audio inputs as they arrive, mimicking real-time interaction scenarios.

Pytorch Implementation

Here’s a PyTorch implementation, written by Claude 3.5 Sonnet:

import torch
import torch.nn as nn
import torch.nn.functional as F
from tokenizers import SentencePieceBPETokenizer

class TransformerBlock(nn.Module):
    """
    A single transformer block, consisting of multi-head attention and a feed-forward network.
    """
    def __init__(self, d_model, nhead):
        super().__init__()
        self.attention = nn.MultiheadAttention(d_model, nhead, batch_first=True)  # inputs are (batch, seq, dim)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.ReLU(),
            nn.Linear(4 * d_model, d_model)
        )

    def forward(self, x, attention_mask):
        # Multi-head attention
        attended, _ = self.attention(x, x, x, attn_mask=attention_mask)
        # Add & Norm
        x = self.norm1(x + attended)
        # Feed-forward network
        fed_forward = self.feed_forward(x)
        # Add & Norm
        return self.norm2(x + fed_forward)

class VQVAEEncoder(nn.Module):
    """
    Vector Quantized Variational Autoencoder (VQ-VAE) Encoder for image tokenization.
    """
    def __init__(self, in_channels, hidden_dim, num_embeddings=65536, embedding_dim=32):
        super().__init__()
        # Encoder network to transform a 768x768 image into a 128x128 latent grid (16,384 positions)
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, hidden_dim, kernel_size=4, stride=2, padding=1),    # 768 -> 384
            nn.ReLU(),
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=5, stride=3, padding=1),     # 384 -> 128
            nn.ReLU(),
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1),     # 128 -> 128
            nn.ReLU(),
            nn.Conv2d(hidden_dim, embedding_dim, kernel_size=3, stride=1, padding=1)   # project to embedding_dim
        )
        # Codebook for vector quantization
        self.codebook = nn.Embedding(num_embeddings, embedding_dim)
        
    def forward(self, x):
        # Encode the input image
        z = self.encoder(x)
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, z.shape[-1])
        
        # Compute distances to codebook vectors
        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(self.codebook.weight ** 2, dim=1) - \
            2 * torch.matmul(z_flattened, self.codebook.weight.t())
        
        # Find nearest codebook vectors
        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.codebook(min_encoding_indices).view(z.shape)
        
        return min_encoding_indices, z_q

class VQVAEDecoder(nn.Module):
    """
    VQ-VAE Decoder for image reconstruction from tokens.
    """
    def __init__(self, out_channels, hidden_dim, embedding_dim):
        super().__init__()
        # Decoder network: 128x128 latent grid -> 768x768 image (mirrors the encoder)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(embedding_dim, hidden_dim, kernel_size=3, stride=1, padding=1),  # 128 -> 128
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_dim, hidden_dim, kernel_size=5, stride=3, padding=1),     # 128 -> 384
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_dim, hidden_dim, kernel_size=4, stride=2, padding=1),     # 384 -> 768
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_dim, out_channels, kernel_size=3, stride=1, padding=1)    # project to RGB
        )
        
    def forward(self, z_q):
        return self.decoder(z_q.permute(0, 3, 1, 2))

class EncoderBlock(nn.Module):
    """
    Encoder block for the audio tokenizer, featuring residual connections and downsampling.
    """
    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        # Residual unit: two convolutions with a skip-connection (length and channels preserved)
        self.residual = nn.Sequential(
            nn.Conv1d(in_channels, in_channels, kernel_size=3, padding=1),
            nn.ELU(),
            nn.Conv1d(in_channels, in_channels, kernel_size=3, padding=1)
        )
        # Strided convolution performs the downsampling and changes the channel count
        self.downsample = nn.Conv1d(in_channels, out_channels, kernel_size=stride*2, stride=stride, padding=stride//2)
        self.elu = nn.ELU()
        self.norm = nn.utils.weight_norm(nn.Conv1d(out_channels, out_channels, kernel_size=1))

    def forward(self, x):
        x = x + self.residual(x)                        # residual unit with skip-connection
        return self.elu(self.norm(self.downsample(x)))  # downsample, normalize, activate

class AudioTokenizer(nn.Module):
    """
    Audio tokenizer using a series of convolutional layers, LSTM, and Residual Vector Quantization (RVQ).
    """
    def __init__(self, input_dim, hidden_dim, num_codebooks, codebook_size):
        super().__init__()
        self.num_codebooks = num_codebooks
        self.codebook_size = codebook_size
        # Channels double after each of the four downsampling blocks: hidden_dim -> 16 * hidden_dim
        latent_dim = hidden_dim * 16

        # Initial convolution
        self.initial_conv = nn.Conv1d(input_dim, hidden_dim, kernel_size=7, padding=3)

        # Encoder blocks with strides (2, 4, 5, 8), doubling the channels each time
        self.encoder_blocks = nn.ModuleList([
            EncoderBlock(hidden_dim * 2**i, hidden_dim * 2**(i + 1), stride)
            for i, stride in enumerate([2, 4, 5, 8])
        ])

        # LSTM for sequence modeling
        self.lstm = nn.LSTM(latent_dim, latent_dim, num_layers=2, batch_first=True)

        # Final convolution producing the latent vectors that RVQ will quantize
        self.final_conv = nn.Conv1d(latent_dim, latent_dim, kernel_size=7, padding=3)

        # RVQ codebooks: each maps codebook_size entries to latent_dim-dimensional vectors
        self.codebooks = nn.ModuleList([nn.Embedding(codebook_size, latent_dim) for _ in range(num_codebooks)])

    def forward(self, x):
        # Initial convolution and activation
        x = F.elu(self.initial_conv(x))

        # Pass through encoder blocks
        for block in self.encoder_blocks:
            x = block(x)

        # LSTM processing: (batch, channels, time) -> (batch, time, latent_dim)
        x = x.transpose(1, 2)
        x, _ = self.lstm(x)

        # Final convolution over the latent sequence
        x = self.final_conv(x.transpose(1, 2)).transpose(1, 2)

        # Residual Vector Quantization: each codebook quantizes what the previous ones left over
        indices = []
        quantized = torch.zeros_like(x)
        residual = x

        for codebook in self.codebooks:
            # Squared distance from the current residual to every codebook entry
            distances = (residual.pow(2).sum(-1, keepdim=True)
                         + codebook.weight.pow(2).sum(-1)
                         - 2 * residual @ codebook.weight.t())
            idx = distances.argmin(dim=-1)
            quant = codebook(idx)
            indices.append(idx)
            quantized = quantized + quant
            residual = residual - quant

        indices = torch.stack(indices, dim=-1)  # (batch, time, num_codebooks)
        return indices, quantized

class GPT4o(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_layers, num_codebooks, audio_vocab_size, image_vocab_size):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.audio_vocab_size = audio_vocab_size
        self.image_vocab_size = image_vocab_size
        self.num_codebooks = num_codebooks

        # Text tokenizer
        self.text_tokenizer = SentencePieceBPETokenizer()
        # Note: In practice, you would train this tokenizer on your corpus separately
        # and save/load it, rather than training it here.
        # self.text_tokenizer.train(files=["path/to/your/corpus.txt"], vocab_size=vocab_size)

        # Embeddings for different modalities
        self.text_embedding = nn.Embedding(vocab_size, self.d_model)
        self.audio_embedding = nn.Embedding(audio_vocab_size * num_codebooks, self.d_model)
        self.image_embedding = nn.Embedding(image_vocab_size, self.d_model)

        # Audio tokenizer
        self.audio_tokenizer = AudioTokenizer(1, 32, num_codebooks, audio_vocab_size)

        # Image encoder (VQ-VAE)
        self.image_encoder = VQVAEEncoder(3, 256, num_embeddings=image_vocab_size, embedding_dim=32)
        self.image_decoder = VQVAEDecoder(3, 256, 32)

        # Positional and temporal encodings (concatenated to the token embeddings, not added)
        self.pos_encoding = nn.Embedding(5000, self.d_model)  # Max 5000 positions within a modality
        self.time_encoding = nn.Embedding(1000, self.d_model)  # Max 1000 time steps (250 seconds at 4 steps/sec)
        # Project the concatenated [token | position | time | source flag] features back to d_model
        # so the transformer blocks can consume them
        self.input_projection = nn.Linear(3 * self.d_model + 1, self.d_model)

        # Main transformer blocks
        self.blocks = nn.ModuleList([TransformerBlock(self.d_model, nhead) for _ in range(num_layers)])

        # Output heads for text/image and audio
        self.text_image_output = nn.Linear(self.d_model, vocab_size + image_vocab_size + 1)  # +1 for "No Output" token
        self.audio_output = nn.ModuleList([
            nn.Linear(self.d_model, audio_vocab_size + 1) for _ in range(num_codebooks)
        ])

    def tokenize_text(self, text):
        # Tokenize the input text
        encoded = self.text_tokenizer.encode(text)
        return torch.tensor(encoded.ids)

    def forward(self, text_input, audio_input, image_input, positions, times, flags, attention_mask):
        # Ensure all inputs are on the same device
        device = self.text_embedding.weight.device
        
        # Tokenize and embed text inputs
        text_tokens = [self.tokenize_text(text).to(device) for text in text_input]
        text_tokens = nn.utils.rnn.pad_sequence(text_tokens, batch_first=True)
        text_emb = self.text_embedding(text_tokens)
        
        # Tokenize and embed audio inputs: offset each codebook's indices into the shared table and sum
        audio_input = audio_input.to(device)
        audio_tokens, _ = self.audio_tokenizer(audio_input)  # (batch, frames, num_codebooks)
        codebook_offsets = torch.arange(self.num_codebooks, device=device) * self.audio_vocab_size
        audio_emb = self.audio_embedding(audio_tokens + codebook_offsets).sum(dim=2)
        
        # Tokenize and embed image inputs: (batch, num_images, 3, 768, 768) -> (batch, num_images * 16384, d_model)
        image_input = image_input.to(device)
        b, n, c, h, w = image_input.shape
        image_tokens, _ = self.image_encoder(image_input.view(b * n, c, h, w))
        image_emb = self.image_embedding(image_tokens.view(b, -1))  # 16,384 tokens per 128x128 latent grid

        # Ensure all embeddings have the same sequence length
        max_len = max(text_emb.shape[1], audio_emb.shape[1], image_emb.shape[1])
        text_emb = F.pad(text_emb, (0, 0, 0, max_len - text_emb.shape[1]))
        audio_emb = F.pad(audio_emb, (0, 0, 0, max_len - audio_emb.shape[1]))
        image_emb = F.pad(image_emb, (0, 0, 0, max_len - image_emb.shape[1]))

        # Combine embeddings from all modalities
        combined_emb = torch.cat([text_emb, audio_emb, image_emb], dim=1)

        # Concatenate positional and temporal encodings plus the source flag to the embeddings,
        # then project back down to d_model for the transformer blocks.
        # positions, times, and flags must have one entry per token of the combined sequence.
        pos_emb = self.pos_encoding(positions)
        time_emb = self.time_encoding(times)
        flags = flags.unsqueeze(-1).float()  # Shape: [batch_size, seq_len, 1]
        combined_emb = torch.cat([combined_emb, pos_emb, time_emb, flags], dim=-1)
        combined_emb = self.input_projection(combined_emb)

        # Pass through transformer blocks
        for block in self.blocks:
            combined_emb = block(combined_emb, attention_mask)

        # Generate outputs for text/image and audio
        text_image_logits = self.text_image_output(combined_emb)
        audio_logits = [head(combined_emb) for head in self.audio_output]

        return text_image_logits, audio_logits

    def generate_image(self, image_tokens):
        # Look up the VQ-VAE codebook vectors (embedding_dim=32) for the predicted tokens
        z_q = self.image_encoder.codebook(image_tokens).view(-1, 128, 128, 32)
        # Decode the 128x128 latent grid back into a 768x768 image
        return self.image_decoder(z_q)

    def decode_text(self, text_tokens):
        # Convert text tokens back to string
        return self.text_tokenizer.decode(text_tokens.tolist())

    def decode_audio(self, audio_tokens):
        # This is a placeholder. In practice, you'd need to implement
        # a method to convert audio tokens back to waveform.
        return audio_tokens

# Example usage (illustrative only: it assumes a trained text tokenizer and that positions,
# times, flags, and the attention mask are sized to match the combined token sequence)
vocab_size = 200019
d_model = 4096
nhead = 32
num_layers = 64
num_codebooks = 4
audio_vocab_size = 65536
image_vocab_size = 65536

# Initialize the GPT-4o model
model = GPT4o(vocab_size, d_model, nhead, num_layers, num_codebooks, audio_vocab_size, image_vocab_size)

# Create dummy inputs for demonstration
batch_size = 4
seq_len = 100
audio_frame_length = 6000  # 250 ms at 24 kHz
num_audio_frames = 4  # 4 frames per second
num_images = 1  # Assuming one image per sequence, adjust as needed

# Text input (list of strings)
text_input = ["This is a sample text input" for _ in range(batch_size)]

# Audio input (tensor)
audio_input = torch.randn(batch_size, 1, audio_frame_length * num_audio_frames)  # 1 second of audio

# Image input (tensor)
image_input = torch.randn(batch_size, num_images, 3, 768, 768)  # Batch of RGB images

# Positions, times, and flags
positions = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)
times = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)
flags = torch.randint(0, 2, (batch_size, seq_len))  # 0 for input, 1 for output

# Attention mask (2D causal mask; True marks positions that may NOT be attended to)
attention_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

# Forward pass through the model
text_image_logits, audio_logits = model(text_input, audio_input, image_input, positions, times, flags, attention_mask)

# Print output shapes
print("Text and Image logits shape:", text_image_logits.shape)
print("Audio logits shape:", [logits.shape for logits in audio_logits])

# Example of using generate_image
image_tokens = torch.randint(0, image_vocab_size, (batch_size, 16384))
generated_image = model.generate_image(image_tokens)
print("Generated image shape:", generated_image.shape)

# Example of decoding text
text_tokens = torch.argmax(text_image_logits[:, :, :vocab_size], dim=-1)
decoded_text = model.decode_text(text_tokens[0])  # Decode the first sequence in the batch
print("Decoded text:", decoded_text)

# Example of decoding audio
audio_tokens = [torch.argmax(logits, dim=-1) for logits in audio_logits]
decoded_audio = model.decode_audio(audio_tokens)
print("Decoded audio shape:", [tokens.shape for tokens in decoded_audio])

If you’ve made it this far and somehow have access to the ungodly number of GPUs required to train this monster, hit me up and let’s give ClosedAI a run for their money.