Open this lesson in your favourite AI. It'll walk you through the why, explain the demo, and quiz you on the try-it list.
When you compute a linear layer's output y = W x + b, you're doing one dot product per output neuron — and a matmul is just those dot products bundled together and executed in parallel. Modern GPUs and TPUs are literally matmul machines: 90%+ of LLM training time is spent inside matmul kernels. If you hold 'matmul = batched dot products' in your head, you'll read every PyTorch shape error correctly the first time.
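To make "one dot product per output neuron" concrete, here is a minimal sketch using NumPy. The sizes (4 outputs, 3 inputs, a batch of 5) and the variable names are made up for illustration; they are not taken from the lesson's demo.

```python
import numpy as np

rng = np.random.default_rng(0)   # fixed seed so the run is repeatable
W = rng.standard_normal((4, 3))  # 4 output neurons, 3 input features (illustrative sizes)
x = rng.standard_normal(3)       # one input vector
b = rng.standard_normal(4)       # one bias per output neuron

# One dot product per output neuron: row i of W against x, plus b[i].
y_rows = np.array([np.dot(W[i], x) + b[i] for i in range(W.shape[0])])

# The same computation written as a single matrix-vector product.
y = W @ x + b

# Stacking several inputs as rows turns the matrix-vector product into a matmul.
X = rng.standard_normal((5, 3))  # 5 inputs, one per row
Y = X @ W.T + b                  # row n of Y equals W @ X[n] + b

print(np.allclose(y_rows, y))    # True
print(Y.shape)                   # (5, 4)
```

The batched line uses `W.T` only because the inputs are stored as rows here; that layout is an assumption of this sketch.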
For C = A @ B where A is (M, K) and B is (K, N), the result C has shape (M, N) and C[i, j] = sum_k A[i, k] * B[k, j]. The inner dimension (K) must match and gets "contracted" away.
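The shape rule and the summation formula can be checked directly. The sketch below uses small NumPy arrays with arbitrary shapes chosen for illustration, compares one entry against the definition, and shows that mismatched inner dimensions raise an error.

```python
import numpy as np

A = np.arange(6).reshape(2, 3)    # (M, K) = (2, 3)
B = np.arange(12).reshape(3, 4)   # (K, N) = (3, 4)
C = A @ B
print(C.shape)                    # (2, 4)

# Check one entry against the definition C[i, j] = sum_k A[i, k] * B[k, j].
i, j = 1, 2
entry = sum(A[i, k] * B[k, j] for k in range(A.shape[1]))
print(int(entry), int(C[i, j]))   # 80 80

# Mismatched inner dimensions: (2, 3) @ (2, 3) has K = 3 on the left but 2 on the right.
try:
    A @ A
    mismatch_raised = False
except ValueError:
    mismatch_raised = True
print(mismatch_raised)            # True
```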
Below: a 3-language naïve implementation. The naïve version does O(M·N·K) FLOPs but ignores the cache hierarchy. Real matmul libraries (cuBLAS, oneMKL) can be 100× faster, largely because they block (tile) the loops so that data is reused while it is still in cache. We'll quantify later. For now, get the shapes right.
[…] A and B in the demo. Will it still compile/run? If not, which shape assertion fails?
[…] @. Measure the speedup ratio.
For seq=2048, d=4096, count how many FLOPs a single Q = X @ W_q matmul does. (Hint: 2 × M × N × K.)

Use these three in order. Each builds on the one before.
Explain matrix multiplication shape rules. Give two pairs of shapes: one that works, one that doesn't, and say why.
Why is matmul the single most-optimized operation in numerical computing? Walk through how a cache-blocked matmul hides memory latency, and what a Tensor Core adds on top.
For GPT-3 175B, estimate what fraction of total FLOPs are matmul vs everything else (attention softmax, layer norm, activations). Use the per-layer FLOP breakdown and explain why matmul dominates.
# main.py: matmul from scratch, then compare to NumPy
import numpy as np

def matmul(A, B):
    # A is M x K and B is K x N, both given as lists of rows
    M, K = len(A), len(A[0])
    K2, N = len(B), len(B[0])
    assert K == K2, f"inner dims must match: got {K} vs {K2}"
    C = [[0.0]*N for _ in range(M)]
    for i in range(M):
        for j in range(N):
            # C[i][j] is the dot product of row i of A with column j of B
            s = 0.0
            for k in range(K):
                s += A[i][k] * B[k][j]
            C[i][j] = s
    return C

A = [[1, 2, 3], [4, 5, 6]]          # (2, 3)
B = [[7, 8], [9, 10], [11, 12]]     # (3, 2)
C = matmul(A, B)                    # (2, 2)
print(C)                            # [[58.0, 64.0], [139.0, 154.0]]

# Compare against NumPy
Cnp = np.array(A) @ np.array(B)
print(Cnp)

Run: python3 main.py
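The cache-blocking idea mentioned earlier can be sketched as a reordering of the same loops: walk the output in small tiles and accumulate tile-by-tile products. This is only an illustration of the loop structure, not how cuBLAS or oneMKL are actually implemented. The function name `matmul_tiled` and the tile size of 2 are assumptions made up for this example, and in Python the tiling will not produce a speedup by itself.

```python
import numpy as np

def matmul_tiled(A, B, tile=2):
    """Blocked matmul sketch: same arithmetic as the triple loop, grouped into tiles."""
    M, K = A.shape
    K2, N = B.shape
    assert K == K2, f"inner dims must match: got {K} vs {K2}"
    C = np.zeros((M, N))
    for i0 in range(0, M, tile):          # tile rows of C
        for j0 in range(0, N, tile):      # tile columns of C
            for k0 in range(0, K, tile):  # tiles along the contracted dimension
                # Slices past the edge are clipped, so sizes need not divide evenly.
                a_blk = A[i0:i0 + tile, k0:k0 + tile]
                b_blk = B[k0:k0 + tile, j0:j0 + tile]
                # The small tile product stands in for a cache-resident inner kernel.
                C[i0:i0 + tile, j0:j0 + tile] += a_blk @ b_blk
    return C

A = np.array([[1, 2, 3], [4, 5, 6]], dtype=float)       # (2, 3)
B = np.array([[7, 8], [9, 10], [11, 12]], dtype=float)  # (3, 2)
print(np.allclose(matmul_tiled(A, B), A @ B))            # True
```

In a compiled implementation the tile size would be chosen so that the three tiles in flight fit in cache; here it is fixed at 2 only to keep the example small.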