Once you have gradients, 'training' is embarrassingly simple: compute the gradient, subtract a small multiple of it from each parameter, repeat. That's it. Every sophisticated optimizer (Momentum, Adam, AdamW, Lion) is a variation on this core loop. Understand the vanilla loop, and when AdamW shows up in Module 7 you'll recognize it as a tweak, not a mystery.
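The core loop fits in a few lines. Below is a minimal sketch, assuming a hypothetical `gradient_descent` helper that takes a starting point, a gradient function, a learning rate, and a step count, and returns the whole trajectory. It is demonstrated on the 1D parabola f(x) = x², whose gradient is 2x. Each descent step multiplies x by 1 - 2·lr, so with lr = 0.1 the trace from x = 3 is 3 → 2.4 → 1.92 → 1.536. Adding the gradient instead (the `ascent=True` flag, included only for comparison) walks uphill: 3 → 3.6 → 4.32 → 5.184.

```python
def gradient_descent(params, grad_fn, lr, steps, ascent=False):
    """Vanilla gradient descent: compute the gradient, step against it, repeat.

    Returns the full trajectory so each step can be inspected.
    ascent=True adds the gradient instead (for comparison only).
    """
    sign = 1.0 if ascent else -1.0
    trajectory = [list(params)]
    for _ in range(steps):
        current = trajectory[-1]
        grads = grad_fn(current)
        # move every parameter by a small multiple of its own gradient
        trajectory.append([p + sign * lr * g for p, g in zip(current, grads)])
    return trajectory


def parabola_grad(params):
    # f(x) = x^2 has gradient 2x; its minimum is at x = 0
    (x,) = params
    return [2.0 * x]


descent = gradient_descent([3.0], parabola_grad, lr=0.1, steps=3)
ascent = gradient_descent([3.0], parabola_grad, lr=0.1, steps=3, ascent=True)
for i, (d, a) in enumerate(zip(descent, ascent)):
    print(f"step {i}: descent x={d[0]:.3f}  ascent x={a[0]:.3f}")
```

Because the helper works on a list of parameters, the same function handles the two-parameter line fit: only `grad_fn` changes.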
Task: fit the line y = 2x + 3 to 100 noisy points by learning w and b from scratch with gradient descent. No framework, no autograd — we hand-derive the two gradients.
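With noisy data, the line that minimises the mean squared error is the ordinary least-squares fit, which lands close to, but not exactly on, w = 2 and b = 3. For one input variable it has a closed form, w = cov(x, y) / var(x) and b = mean(y) - w·mean(x), so it makes a handy reference for what gradient descent should converge to. A minimal sketch that regenerates main.py's data with the same seed; `least_squares_line` is a hypothetical helper, and the closed form is a cross-check swapped in for comparison, not part of the lesson's method.

```python
import random

# same data as main.py: seed 0, 100 points on y = 2x + 3 plus Gaussian noise
random.seed(0)
xs = [random.uniform(-5, 5) for _ in range(100)]
ys = [2 * x + 3 + random.gauss(0, 0.3) for x in xs]


def least_squares_line(xs, ys):
    """Closed-form least-squares slope and intercept for one input variable."""
    n = len(xs)
    x_mean = sum(xs) / n
    y_mean = sum(ys) / n
    # the 1/n factors of covariance and variance cancel in the ratio
    cov_xy = sum((x - x_mean) * (y - y_mean) for x, y in zip(xs, ys))
    var_x = sum((x - x_mean) ** 2 for x in xs)
    w = cov_xy / var_x
    b = y_mean - w * x_mean
    return w, b


w_star, b_star = least_squares_line(xs, ys)
print(f"least-squares optimum: w={w_star:.3f}, b={b_star:.3f}")
```

At this point the averaged gradients from the formulas below are zero (up to rounding), which is exactly the condition gradient descent is driving toward.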
Loss: mean squared error. Per-sample gradient: dL/dw = 2(w*x + b - y) * x, dL/db = 2(w*x + b - y).
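Averaging over all N points gives the full-batch loss and gradients. Writing r_i = w x_i + b - y_i for the residual of point i, and using ∂r_i/∂w = x_i and ∂r_i/∂b = 1, the chain rule gives:

```latex
\begin{aligned}
L(w, b) &= \frac{1}{N} \sum_{i=1}^{N} \left( w x_i + b - y_i \right)^2 = \frac{1}{N} \sum_{i=1}^{N} r_i^2 \\
\frac{\partial L}{\partial w} &= \frac{1}{N} \sum_{i=1}^{N} 2 r_i \frac{\partial r_i}{\partial w} = \frac{1}{N} \sum_{i=1}^{N} 2 \left( w x_i + b - y_i \right) x_i \\
\frac{\partial L}{\partial b} &= \frac{1}{N} \sum_{i=1}^{N} 2 r_i \frac{\partial r_i}{\partial b} = \frac{1}{N} \sum_{i=1}^{N} 2 \left( w x_i + b - y_i \right)
\end{aligned}
```

Each full-batch gradient is the mean of the per-sample gradients, which is what mean(dL/dw) and mean(dL/db) denote in the update rule.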
Update rule: w -= lr * mean(dL/dw), b -= lr * mean(dL/db). Run for 200 iterations and watch w approach 2 and b approach 3.
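A hand-derived gradient is easy to get subtly wrong, so a common sanity check (not part of the lesson's code) is to compare it with a central finite-difference estimate, (L(w + h) - L(w - h)) / 2h. A minimal sketch on the same data as main.py; `mse`, `analytic_grads`, and `numeric_grads` are hypothetical helper names, and the probe point (w, b) = (0.5, -1.0) and step h = 1e-5 are arbitrary choices.

```python
import random

# same data as main.py
random.seed(0)
xs = [random.uniform(-5, 5) for _ in range(100)]
ys = [2 * x + 3 + random.gauss(0, 0.3) for x in xs]


def mse(w, b):
    return sum((w * x + b - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def analytic_grads(w, b):
    # the hand-derived formulas, averaged over all points
    n = len(xs)
    dw = sum(2 * (w * x + b - y) * x for x, y in zip(xs, ys)) / n
    db = sum(2 * (w * x + b - y) for x, y in zip(xs, ys)) / n
    return dw, db


def numeric_grads(w, b, h=1e-5):
    # central differences: nudge one parameter up and down, measure the slope
    dw = (mse(w + h, b) - mse(w - h, b)) / (2 * h)
    db = (mse(w, b + h) - mse(w, b - h)) / (2 * h)
    return dw, db


analytic = analytic_grads(0.5, -1.0)
numeric = numeric_grads(0.5, -1.0)
print("analytic:", analytic)
print("numeric: ", numeric)
```

Because the MSE is quadratic in w and b, central differences are exact here apart from floating-point rounding, so any disagreement beyond a tiny tolerance points to a bug in the derivation.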
Try it:
- Change lr from 0.01 to 1.0 and re-run. Observe the divergence: the updates overshoot and the loss explodes (with plain Python floats, main.py eventually stops with an OverflowError once the squared error no longer fits in a float).
- Set lr to 0.0001 and re-run for 200 steps. Does it converge? Increase the number of steps to 10000 and compare.
- Change the data to y = 2x^2 + 3 and re-run the same fitter. Plot the final line: it's the best linear fit to a quadratic. Why is this the 'bias' of linear models?

Use these three prompts in order. Each builds on the one before.
Define gradient descent in one paragraph. Explain each of the three knobs: learning rate, initialization, number of steps.
Walk me through why we subtract (not add) the gradient, and why we scale by a learning rate. Use a 1D parabola and trace three steps.
Explain why plain SGD is almost never used for LLMs. Compare to Momentum, RMSProp, Adam, and AdamW — what each fixes and what the optimizer state (memory overhead per parameter) looks like.
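The optimizer-state comparison can be made concrete with textbook update rules for a single scalar parameter. A minimal sketch, assuming a toy loss f(x) = (x - 3)², the PyTorch-style momentum form (v ← βv + g), and commonly used default hyperparameters (β = 0.9 for momentum, ρ = 0.99 for RMSProp, β₁ = 0.9, β₂ = 0.999, ε = 1e-8 for Adam); the function names are hypothetical. Each function also returns how many extra numbers it stores per parameter: 0 for SGD, 1 for Momentum (velocity), 1 for RMSProp (running mean of squared gradients), and 2 for Adam and AdamW (first and second moments). Here `weight_decay=0` gives plain Adam and `weight_decay > 0` gives AdamW's decoupled update, which shrinks the parameter directly instead of adding the decay to the gradient.

```python
import math


def grad(x):
    # gradient of the toy loss f(x) = (x - 3)^2, minimised at x = 3
    return 2.0 * (x - 3.0)


def sgd(x, steps, lr):
    # plain SGD: no optimizer state (0 extra numbers per parameter)
    for _ in range(steps):
        x -= lr * grad(x)
    return x, 0


def momentum(x, steps, lr, beta=0.9):
    v = 0.0  # velocity: 1 extra number per parameter
    for _ in range(steps):
        v = beta * v + grad(x)
        x -= lr * v
    return x, 1


def rmsprop(x, steps, lr, rho=0.99, eps=1e-8):
    s = 0.0  # running mean of squared gradients: 1 extra number per parameter
    for _ in range(steps):
        g = grad(x)
        s = rho * s + (1 - rho) * g * g
        x -= lr * g / (math.sqrt(s) + eps)
    return x, 1


def adam(x, steps, lr, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.0):
    # first moment m and second moment v: 2 extra numbers per parameter
    # (the step counter t is shared by all parameters, so it is not counted)
    m = v = 0.0
    for t in range(1, steps + 1):
        g = grad(x)
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        m_hat = m / (1 - beta1 ** t)  # bias correction for the zero init
        v_hat = v / (1 - beta2 ** t)
        # weight_decay > 0 gives AdamW: decay acts on x directly,
        # decoupled from the adaptive gradient step
        x -= lr * (m_hat / (math.sqrt(v_hat) + eps) + weight_decay * x)
    return x, 2


results = {
    "SGD": sgd(0.0, 2000, lr=0.05),
    "Momentum": momentum(0.0, 2000, lr=0.05),
    "RMSProp": rmsprop(0.0, 2000, lr=0.05),
    "Adam": adam(0.0, 2000, lr=0.05),
    "AdamW": adam(0.0, 2000, lr=0.05, weight_decay=0.01),
}
for name, (x, state) in results.items():
    print(f"{name:8s} x={x:.4f}  extra state per parameter: {state}")
```

For a model with P parameters stored in fp32, those state counts translate directly into memory: Adam and AdamW keep two extra P-sized buffers on top of the weights and gradients.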
```python
# main.py: fit y = 2x + 3 with vanilla gradient descent
import random

random.seed(0)

# data: 100 points on y = 2x + 3 with Gaussian noise (sigma = 0.3)
xs = [random.uniform(-5, 5) for _ in range(100)]
ys = [2*x + 3 + random.gauss(0, 0.3) for x in xs]

# parameters: start far from the truth
w, b, lr = 0.0, 0.0, 0.01

for step in range(200):
    # forward
    preds = [w*x + b for x in xs]
    loss = sum((p - y)**2 for p, y in zip(preds, ys)) / len(xs)

    # backward (hand-derived)
    dw = sum(2*(p - y)*x for p, y, x in zip(preds, ys, xs)) / len(xs)
    db = sum(2*(p - y) for p, y in zip(preds, ys)) / len(xs)

    # update
    w -= lr * dw
    b -= lr * db

    if step % 40 == 0:
        print(f"step {step:3d} loss={loss:.4f} w={w:.3f} b={b:.3f}")

print(f"final: w={w:.3f}, b={b:.3f} (truth: w=2, b=3)")
```

Run it with:

```shell
python3 main.py
```
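The first two experiments in the try-it list can be run side by side. A minimal sketch, assuming a hypothetical `fit` helper that reruns main.py's loop for a given learning rate and step count, and stops early (reporting the run as diverged) once the loss is no longer finite or passes an arbitrary 1e12 threshold.

```python
import math
import random

# same data as main.py
random.seed(0)
xs = [random.uniform(-5, 5) for _ in range(100)]
ys = [2 * x + 3 + random.gauss(0, 0.3) for x in xs]


def fit(lr, steps, blowup=1e12):
    """Run main.py's gradient-descent loop; return (w, b, loss, diverged)."""
    w, b = 0.0, 0.0
    n = len(xs)
    loss = float("nan")
    for _ in range(steps):
        residuals = [w * x + b - y for x, y in zip(xs, ys)]
        loss = sum(r * r for r in residuals) / n
        # stop before the numbers overflow; treat this run as diverged
        if not math.isfinite(loss) or loss > blowup:
            return w, b, loss, True
        dw = sum(2 * r * x for r, x in zip(residuals, xs)) / n
        db = sum(2 * r for r in residuals) / n
        w -= lr * dw
        b -= lr * db
    return w, b, loss, False


settings = [(0.01, 200), (1.0, 200), (0.0001, 200), (0.0001, 10000)]
runs = {(lr, steps): fit(lr, steps) for lr, steps in settings}
for (lr, steps), (w, b, loss, diverged) in runs.items():
    status = "DIVERGED" if diverged else f"w={w:.3f} b={b:.3f} loss={loss:.4f}"
    print(f"lr={lr:<7} steps={steps:<6} {status}")
```

Why lr = 1.0 blows up: along w the loss has curvature 2·mean(x²), roughly 2·25/3 ≈ 16.7 for x uniform on [-5, 5], and gradient descent on a quadratic is only stable when lr is below 2 divided by the largest curvature, about 0.12 here. At lr = 0.0001 the b direction (curvature about 2) shrinks its error by only a factor of about 0.9998 per step, so even after 10000 steps b still lags noticeably behind w.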