Commit a46b11e6 authored by Anton Akhmerov's avatar Anton Akhmerov
Browse files

switch modes to use @ instead of dot

parent a8947abf
......@@ -19,7 +19,6 @@ from scipy.linalg import block_diag
from scipy.sparse import (identity as sp_identity, hstack as sp_hstack,
csr_matrix)
dot = np.dot
__all__ = ['selfenergy', 'modes', 'PropagatingModes', 'StabilizedModes']
......@@ -169,7 +168,7 @@ class StabilizedModes:
outgoing = slice(self.nmodes, None)
vecs = self.vecs[:, outgoing]
vecslmbdainv = self.vecslmbdainv[:, outgoing]
return dot(v, dot(vecs, la.solve(vecslmbdainv, v.T.conj())))
return v @ vecs @ la.solve(vecslmbdainv, v.T.conj())
# Auxiliary functions that perform different parts of the calculation.
......@@ -248,7 +247,7 @@ def setup_linsys(h_cell, h_hop, tol=1e6, stabilization=None):
A_inv = la.inv(A)
lhs = np.zeros((2*n, 2*n), dtype=np.common_type(h_cell, h_hop))
lhs[:n, :n] = -dot(A_inv, h_cell) * B_H_inv
lhs[:n, :n] = -(A_inv @ h_cell) * B_H_inv
lhs[:n, n:] = -A_inv * B
lhs[n:, :n] = A.T.conj() * B_H_inv
......@@ -301,7 +300,7 @@ def setup_linsys(h_cell, h_hop, tol=1e6, stabilization=None):
need_to_stabilize = True
# Matrices are complex or need self-energy-like term to be
# stabilized.
temp = dot(u, u.T.conj()) + dot(v, v.T.conj())
temp = u @ u.T.conj() + v @ v.T.conj()
h = h_cell + 1j * temp
sol = kla.lu_factor(h)
......@@ -317,10 +316,10 @@ def setup_linsys(h_cell, h_hop, tol=1e6, stabilization=None):
# the projected one (v^dagger psi lambda^-1, s u^dagger psi).
def extract_wf(psi, lmbdainv):
wf = -dot(u, psi[: n_nonsing] * lmbdainv) - dot(v, psi[n_nonsing:])
wf = -u @ (psi[: n_nonsing] * lmbdainv) - v @ psi[n_nonsing:]
if need_to_stabilize:
wf += 1j * (dot(v, psi[: n_nonsing]) +
dot(u, psi[n_nonsing:] * lmbdainv))
wf += 1j * (v @ psi[: n_nonsing] +
u @ (psi[n_nonsing:] * lmbdainv))
return kla.lu_solve(sol, wf)
# Setup the generalized eigenvalue problem.
......@@ -332,22 +331,22 @@ def setup_linsys(h_cell, h_hop, tol=1e6, stabilization=None):
A[end, begin] = np.identity(n_nonsing)
temp = kla.lu_solve(sol, v)
temp2 = dot(u.T.conj(), temp)
temp2 = u.T.conj() @ temp
if need_to_stabilize:
A[begin, begin] = -1j * temp2
A[begin, end] = temp2
temp2 = dot(v.T.conj(), temp)
temp2 = v.T.conj() @ temp
if need_to_stabilize:
A[end, begin] -= 1j *temp2
A[end, end] = temp2
B[begin, end] = -np.identity(n_nonsing)
temp = kla.lu_solve(sol, u)
temp2 = dot(u.T.conj(), temp)
temp2 = u.T.conj() @ temp
B[begin, begin] = -temp2
if need_to_stabilize:
B[begin, end] += 1j * temp2
temp2 = dot(v.T.conj(), temp)
temp2 = v.T.conj() @ temp
B[end, begin] = -temp2
if need_to_stabilize:
B[end, end] = 1j * temp2
......@@ -481,15 +480,15 @@ def phs_symmetrization(wfs, particle_hole):
def Pdot(mat):
"""Apply the particle-hole operator to an array. """
return particle_hole.dot(mat.conj())
return particle_hole @ mat.conj()
# Take P in the subspace of W = wfs: U = W^+ @ P @ W^*.
U = wfs.T.conj().dot(Pdot(wfs))
U = wfs.T.conj() @ Pdot(wfs)
# Check that wfs are orthonormal and the space spanned
# by them is closed under ph, meaning U is unitary.
if not np.allclose(U.dot(U.T.conj()), np.eye(U.shape[0])):
if not np.allclose(U @ U.T.conj(), np.eye(U.shape[0])):
raise ValueError('wfs are not orthonormal or not closed under particle_hole.')
P_squared = U.dot(U.conj())
P_squared = U @ U.conj()
if np.allclose(P_squared, np.eye(U.shape[0])):
P_squared = 1
elif np.allclose(P_squared, -np.eye(U.shape[0])):
......@@ -515,12 +514,12 @@ def phs_symmetrization(wfs, particle_hole):
shift = -np.pi - (phases[i] + dph[i]/2)
# Take matrix square root with branch cut in largest gap
vals = np.sqrt(vals * np.exp(1j * shift)) * np.exp(-0.5j * shift)
sqrtU = vecs.dot(np.diag(vals)).dot(vecs.T.conj())
sqrtU = vecs @ np.diag(vals) @ vecs.T.conj()
# For symmetric U sqrt(U) is also symmetric.
assert np.allclose(sqrtU, sqrtU.T)
# We want a new basis W_new such that W_new^+ @ P @ W_new^* = 1.
# This is achieved by W_new = W @ sqrt(U).
new_wfs = wfs.dot(sqrtU)
new_wfs = wfs @ sqrtU
# If P^2 = 1, there is no need to sort the modes further.
TRIM_sort = np.zeros((wfs.shape[1],), dtype=int)
else:
......@@ -556,12 +555,11 @@ def phs_symmetrization(wfs, particle_hole):
wfs = wfs[:, 1:]
# Now we project the remaining modes onto the orthogonal
# complement of P psi_n. projector:
projector = wfs.dot(wfs.T.conj()) - \
np.outer(P_wf, P_wf.T.conj())
projector = wfs @ wfs.T.conj() - np.outer(P_wf, P_wf.T.conj())
# After the projection, the mode matrix is rank deficient -
# the span of the column space has dimension one less than
# the number of columns.
wfs = projector.dot(wfs)
wfs = projector @ wfs
wfs = la.qr(wfs, mode='economic', pivoting=True)[0]
# Remove the redundant column.
wfs = wfs[:, :-1]
......@@ -574,14 +572,13 @@ def phs_symmetrization(wfs, particle_hole):
# Store psi_n and P psi_n.
new_wfs.append(wf)
new_wfs.append(Pdot(wf))
assert np.allclose(wfs.T.conj().dot(wfs),
np.eye(wfs.shape[1]))
assert np.allclose(wfs.T.conj() @ wfs, np.eye(wfs.shape[1]))
new_wfs = np.hstack([col.reshape(len(col), 1)/npl.norm(col) for
col in new_wfs])
assert np.allclose(new_wfs[:, 1::2], Pdot(new_wfs[:, ::2]))
# Store sort ordering in this subspace of modes
TRIM_sort = np.arange(new_wfs.shape[1])
assert np.allclose(new_wfs.T.conj().dot(new_wfs), np.eye(new_wfs.shape[1]))
assert np.allclose(new_wfs.T.conj() @ new_wfs, np.eye(new_wfs.shape[1]))
return new_wfs, TRIM_sort
......@@ -674,7 +671,7 @@ def make_proper_modes(lmbdainv, psi, extract, tol, particle_hole,
# must have the same velocity). However, this does not matter,
# since we are happy with any superposition in this case.
vel_op = -1j * dot(psi[n:, indx].T.conj(), psi[:n, indx])
vel_op = -1j * psi[n:, indx].T.conj() @ psi[:n, indx]
vel_op = vel_op + vel_op.T.conj()
vel_vals, rot = la.eigh(vel_op)
......@@ -686,8 +683,8 @@ def make_proper_modes(lmbdainv, psi, extract, tol, particle_hole,
if full_psi.dtype != np.common_type(full_psi, rot):
full_psi = full_psi.astype(np.common_type(psi, rot))
psi[:, indx] = dot(psi[:, indx], rot)
full_psi[:, indx] = dot(full_psi[:, indx], rot)
psi[:, indx] = psi[:, indx] @ rot
full_psi[:, indx] = full_psi[:, indx] @ rot
velocities[indx] = vel_vals
# With particle-hole symmetry, treat TRIMs individually.
......@@ -757,18 +754,18 @@ def make_proper_modes(lmbdainv, psi, extract, tol, particle_hole,
full_psi[:, indx] = new_wf
# For both cases P^2 = +-1, must rotate wave functions in the
# singular value basis. Find the rotation from new basis to old.
rot = new_wf.T.conj().dot(orig_wf)
rot = new_wf.T.conj() @ orig_wf
# Rotate the wave functions in the singular value basis
psi[:, indx] = psi[:, indx].dot(rot.T.conj())
psi[:, indx] = psi[:, indx] @ rot.T.conj()
# Ensure proper usage of chiral symmetry.
if chiral is not None and time_reversal is None:
out_orig = full_psi[:, indx[len(indx)//2:]]
out = chiral.dot(full_psi[:, indx[:len(indx)//2]])
out = chiral @ full_psi[:, indx[:len(indx)//2]]
# No least squares below because the modes should be orthogonal.
rot = out_orig.T.conj().dot(out)
rot = out_orig.T.conj() @ out
full_psi[:, indx[len(indx)//2:]] = out
psi[:, indx[len(indx)//2:]] = psi[:, indx[len(indx)//2:]].dot(rot)
psi[:, indx[len(indx)//2:]] = psi[:, indx[len(indx)//2:]] @ rot
if np.any(abs(velocities) < vel_eps):
raise RuntimeError("Found a mode with zero or close to zero velocity.")
......@@ -806,12 +803,12 @@ def make_proper_modes(lmbdainv, psi, extract, tol, particle_hole,
# is [-k2, -k1, k1, k2], if k2, k1 > 0 and k2 > k1.
# To maintain this ordering with ki and -ki as particle-hole partners,
# reverse the order of the product at the end.
wf_neg_k = particle_hole.dot(
wf_neg_k = (particle_hole @
(full_psi[:, :N][:, positive_k]).conj())[:, ::-1]
rot = la.lstsq(orig_neg_k, wf_neg_k)[0]
full_psi[:, :N][:, positive_k[::-1]] = wf_neg_k
psi[:, :N][:, positive_k[::-1]] = \
psi[:, :N][:, positive_k[::-1]].dot(rot)
psi[:, :N][:, positive_k[::-1]] @ rot
# Outgoing modes
positive_k = (np.pi - eps > momenta[N:]) * (momenta[N:] > eps)
......@@ -821,12 +818,12 @@ def make_proper_modes(lmbdainv, psi, extract, tol, particle_hole,
# is like [k2, k1, -k1, -k2], if k2, k1 > 0 and k2 > k1.
# Reverse order at the end to match momenta of opposite sign.
wf_neg_k = particle_hole.dot(
wf_neg_k = (particle_hole @
full_psi[:, N:][:, positive_k].conj())[:, ::-1]
rot = la.lstsq(orig_neg_k, wf_neg_k)[0]
full_psi[:, N:][:, positive_k[::-1]] = wf_neg_k
psi[:, N:][:, positive_k[::-1]] = \
psi[:, N:][:, positive_k[::-1]].dot(rot)
psi[:, N:][:, positive_k[::-1]] @ rot
# Modes are ordered by velocity.
# Use time-reversal symmetry to relate modes of opposite velocity.
......@@ -834,10 +831,10 @@ def make_proper_modes(lmbdainv, psi, extract, tol, particle_hole,
# Note: within this function, nmodes refers to the total number
# of propagating modes, not either left or right movers.
out_orig = full_psi[:, nmodes//2:]
out = time_reversal.dot(full_psi[:, :nmodes//2].conj())
out = time_reversal @ full_psi[:, :nmodes//2].conj()
rot = la.lstsq(out_orig, out)[0]
full_psi[:, nmodes//2:] = out
psi[:, nmodes//2:] = psi[:, nmodes//2:].dot(rot)
psi[:, nmodes//2:] = psi[:, nmodes//2:] @ rot
norm = np.sqrt(abs(velocities))
full_psi = full_psi / norm
......@@ -951,8 +948,8 @@ def transform_modes(modes_data, unitary=None, time_reversal=None,
if flip_energy != conj:
velocities *= -1
wave_functions = unitary.dot(wave_functions)[:, perm]
v = unitary.dot(v)
wave_functions = (unitary @ wave_functions)[:, perm]
v = unitary @ v
vecs[:, :2*nmodes] = vecs[:, perm]
vecslmbdainv[:, :2*nmodes] = vecslmbdainv[:, perm]
velocities = velocities[perm]
......@@ -1069,13 +1066,7 @@ def modes(h_cell, h_hop, tol=1e6, stabilization=None, *,
def basis_change(a, antiunitary=False):
b = projection_op
# We need the extra transposes to ensure that sparse dot is used.
if antiunitary:
# b.T.conj() @ a @ b.conj()
return (b.T.conj().dot((b.T.conj().dot(a)).T)).T
else:
# b.T.conj() @ a @ b
return (b.T.dot((b.T.conj().dot(a)).T)).T
return b.T.conj() @ a @ (b.conj() if antiunitary else b)
# Conservation law basis
ham_cons = basis_change(ham)
......@@ -1120,7 +1111,7 @@ def modes(h_cell, h_hop, tol=1e6, stabilization=None, *,
vecs, vecslmbdainv, sqrt_hops) = zip(*block_modes)
# Reorder by direction of propagation
wave_functions = group_halves([(projector.dot(wf)).T for wf, projector in
wave_functions = group_halves([(projector @ wf).T for wf, projector in
zip(wave_functions, projectors)]).T
# Propagating modes object to return
prop_modes = PropagatingModes(wave_functions, group_halves(velocities),
......@@ -1140,7 +1131,7 @@ def modes(h_cell, h_hop, tol=1e6, stabilization=None, *,
for n, v in zip(nmodes, vecslmbdainv)))
vecslmbdainv = np.hstack([block_diag(*part) for part in parts])
sqrt_hops = np.hstack([projector.dot(hop) for projector, hop in
sqrt_hops = np.hstack([projector @ hop for projector, hop in
zip(projectors, sqrt_hops)])[:h_hop.shape[1]]
stab_modes = StabilizedModes(vecs, vecslmbdainv, sum(nmodes), sqrt_hops)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment