diff --git a/code/semiconductors.py b/code/semiconductors.py
index b01f1b0ca2a0c50d43eb0744986800b062ce0d45..15666ded922418770396942047affb88e202ff8e 100644
--- a/code/semiconductors.py
+++ b/code/semiconductors.py
@@ -14,6 +14,7 @@ from common import draw_classic_axes
 E_V, E_C, E_F = -1.2, 1.8, .4
 E_D, E_A = E_C - .7, E_V + .5
 m_h, m_e = 1, .5
+sqrt_plus = lambda x: np.sqrt(x * (x >= 0))
 
 
 def plot_dos():
@@ -23,7 +24,6 @@ def plot_dos():
     n_F = 1/(np.exp(2*(E - E_F)) + 1)
     g_e = m_e * sqrt_plus(E - E_C)
     g_h = m_h * sqrt_plus(E_V - E)
-    sqrt_plus = lambda x: np.sqrt(x * (x >= 0))
     ax.plot(E, g_h, label="$g_e$")
     ax.plot(E, g_e, label="$g_h$")
     ax.plot(E, 10 * g_h * (1-n_F), label="$n_h$")