diff --git a/src/3_drude_model.md b/src/3_drude_model.md
index fccf02cab6910a85411e9a115edab0aee5c057e9..c644b296d79ac36bdef73991bea26ca0be393cc8 100644
--- a/src/3_drude_model.md
+++ b/src/3_drude_model.md
@@ -1,6 +1,7 @@
 ```python tags=["initialize"]
 from matplotlib import pyplot
-
+import matplotlib.animation as animation
+from IPython.display import HTML
 import numpy as np
 
 from common import draw_classic_axes, configure_plotting
@@ -61,21 +62,21 @@ Even under these simplistic assumptions, the trajectory of the electrons is hard
 Due to the random scattering, each trajectory is different, and this is how several example trajectories look:
 
 ```python
-%matplotlib inline
-import matplotlib.pyplot as plt
-import numpy as np
-import matplotlib.animation as animation
-from IPython.display import HTML
+# Use colors from the default color cycle
+default_colors = pyplot.rcParams['axes.prop_cycle'].by_key()['color']
+blue, red = default_colors[0], default_colors[3]
 
 walkers = 20 # number of particles
 tau = 1 # relaxation time
 gamma = .3 # dissipation strength
 a = 1 # acceleration
 dt = .1 # infinitesimal
-T = 20 # simulation time
+T = 10 # simulation time
 
 v = np.zeros((2, int(T // dt), walkers), dtype=float) #
 
+# Select random time steps and scattering angles
+np.random.seed(1)
 scattering_events = np.random.binomial(1, dt/tau, size=v.shape[1:])
 angles = np.random.uniform(high=2*np.pi, size=scattering_events.shape) * scattering_events
 rotations = np.array(
@@ -97,47 +98,33 @@ r = np.cumsum(v * dt, axis=1)
 scattering_positions = np.copy(r)
 scattering_positions[:, ~scattering_events.astype(bool)] = np.nan
 
-fig = plt.figure()
-
-scatter_pts = scattering_positions[:, :100]
-trace = r[:, :100]
+scatter_pts = scattering_positions
 
-nz_scatters = tuple((np.hstack(scatter_pts[0])[~np.isnan(np.hstack(scatter_pts[0]))],
-                    np.hstack(scatter_pts[1])[~np.isnan(np.hstack(scatter_pts[1]))]))
+r_min = np.min(r.reshape(2, -1), axis=1) - 1
+r_max = np.max(r.reshape(2, -1), axis=1) + 1
 
-plt.axis([min(nz_scatters[0])-1,
-          max(nz_scatters[0])+1,
-          min(nz_scatters[1])-1,
-          max(nz_scatters[1])+1])
+fig = pyplot.figure(figsize=(9, 6))
+ax = fig.add_subplot(1, 1, 1)
+ax.axis("off")
+ax.set(xlim=(r_min[0], r_max[0]), ylim=(r_min[1], r_max[1]))
 
-lines = []
-scatterers = []
-for index in range(walkers):
-    lobj = plt.plot([],[], lw=1, color='b', alpha=0.5)[0]
-    lines.append(lobj)
-    scatterers.append(plt.scatter([], [], s=10, c='r'))
+trajectories = ax.plot([],[], lw=1, color=blue, alpha=0.5)[0]
+scatterers = ax.scatter([], [], s=20, c=red)
 
-def animate(i):
-    for lnum, line in enumerate(lines):
-        line.set_data(trace[0][:i, lnum], trace[1][:i, lnum])
-        data = np.stack((scatter_pts[0][:i,lnum], scatter_pts[1][:i, lnum])).T
-        scatterers[lnum].set_offsets(data)
+def frame(i):
+    concatenated_lines = np.concatenate(
+        (r[:, :i], np.nan * np.ones((2, 1, walkers))),
+        axis=1
+    ).transpose(0, 2, 1).reshape(2, -1)
+    trajectories.set_data(*concatenated_lines)
+    scatterers.set_offsets(scatter_pts[:, :i].reshape(2, -1).T)
 
-anim = animation.FuncAnimation(fig, animate, interval=100)
+anim = animation.FuncAnimation(fig, frame, interval=100)
 
-def remove_axes(ax):
-    ax.spines['bottom'].set_color('white')
-    ax.spines['top'].set_color('white') 
-    ax.spines['right'].set_color('white')
-    ax.spines['left'].set_color('white')
-    ax.tick_params(axis='x', colors='white')
-    ax.tick_params(axis='y', colors='white')
-remove_axes(plt.gca());
-plt.close();
+pyplot.close()
 
 HTML(anim.to_html5_video())
 ```
----
 
 ### Equations of motion
 Our goal is finding the *electric current density* $j$.