Skip to content
Snippets Groups Projects
Commit bd885029 authored by Joseph Weston's avatar Joseph Weston
Browse files

remove superfluous return value from 'prepare_for_lapack'

parent 727a181a
No related branches found
No related tags found
No related merge requests found
......@@ -50,5 +50,5 @@ def gen_eig(a, b, left=False, right=True, overwrite_ab=False):
The right eigenvector corresponding to the eigenvalue
``alpha[i]/beta[i]`` is the column ``vr[:,i]``.
"""
ltype, a, b = lapack.prepare_for_lapack(overwrite_ab, a, b)
a, b = lapack.prepare_for_lapack(overwrite_ab, a, b)
return lapack.ggev(a, b, left, right)
......@@ -45,7 +45,7 @@ def lu_factor(a, overwrite_a=False):
singular : boolean
Whether the matrix a is singular (up to machine precision)
"""
ltype, a = lapack.prepare_for_lapack(overwrite_a, a)
a = lapack.prepare_for_lapack(overwrite_a, a)
return lapack.getrf(a)
......@@ -71,7 +71,7 @@ def lu_solve(matrix_factorization, b):
"a singular matrix. Result of solve step "
"are probably unreliable")
ltype, lu, b = lapack.prepare_for_lapack(False, lu, b)
lu, b = lapack.prepare_for_lapack(False, lu, b)
ipiv = np.ascontiguousarray(np.asanyarray(ipiv), dtype=lapack.int_dtype)
return lapack.getrs(lu, ipiv, b)
......@@ -102,5 +102,5 @@ def rcond_from_lu(matrix_factorization, norm_a, norm="1"):
"""
(lu, ipiv, singular) = matrix_factorization
norm = norm.encode('utf8') # lapack expects bytes
ltype, lu = lapack.prepare_for_lapack(False, lu)
lu = lapack.prepare_for_lapack(False, lu)
return lapack.gecon(lu, norm_a, norm)
......@@ -62,7 +62,7 @@ def schur(a, calc_q=True, calc_ev=True, overwrite_a=False):
LinAlgError
If the underlying QR iteration fails to converge.
"""
ltype, a = lapack.prepare_for_lapack(overwrite_a, a)
a = lapack.prepare_for_lapack(overwrite_a, a)
return lapack.gees(a, calc_q, calc_ev)
......@@ -182,7 +182,7 @@ def order_schur(select, t, q, calc_ev=True, overwrite_tq=False):
``calc_ev == True``
"""
ltype, t, q = lapack.prepare_for_lapack(overwrite_tq, t, q)
t, q = lapack.prepare_for_lapack(overwrite_tq, t, q)
# Figure out if select is a function or array.
isfun = isarray = True
......@@ -255,7 +255,7 @@ def evecs_from_schur(t, q, select=None, left=False, right=True,
``right == True``.
"""
ltype, t, q = lapack.prepare_for_lapack(overwrite_tq, t, q)
t, q = lapack.prepare_for_lapack(overwrite_tq, t, q)
# check if select is a function or an array
if select is not None:
......@@ -346,7 +346,7 @@ def gen_schur(a, b, calc_q=True, calc_z=True, calc_ev=True,
LinAlError
If the underlying QZ iteration fails to converge.
"""
ltype, a, b = lapack.prepare_for_lapack(overwrite_ab, a, b)
a, b = lapack.prepare_for_lapack(overwrite_ab, a, b)
return lapack.gges(a, b, calc_q, calc_z, calc_ev)
......@@ -408,7 +408,7 @@ def order_gen_schur(select, s, t, q=None, z=None, calc_ev=True,
LinAlError
If the problem is too ill-conditioned.
"""
ltype, s, t, q, z = lapack.prepare_for_lapack(overwrite_stqz, s, t, q, z)
s, t, q, z = lapack.prepare_for_lapack(overwrite_stqz, s, t, q, z)
# Figure out if select is a function or array.
......@@ -491,7 +491,7 @@ def convert_r2c_gen_schur(s, t, q=None, z=None):
If it fails to convert a 2x2 block into complex form (unlikely).
"""
ltype, s, t, q, z = lapack.prepare_for_lapack(True, s, t, q, z)
s, t, q, z = lapack.prepare_for_lapack(True, s, t, q, z)
# Note: overwrite=True does not mean much here, the arrays are all copied
if (s.ndim != 2 or t.ndim != 2 or
......@@ -611,7 +611,7 @@ def evecs_from_gen_schur(s, t, q=None, z=None, select=None,
"""
ltype, s, t, q, z = lapack.prepare_for_lapack(overwrite_qz, s, t, q, z)
s, t, q, z = lapack.prepare_for_lapack(overwrite_qz, s, t, q, z)
if left and q is None:
raise ValueError("Matrix q must be provided for left eigenvectors")
......
......@@ -1274,8 +1274,7 @@ def prepare_for_lapack(overwrite, *args):
If an argument is ``None``, it is just passed through and not used to
determine the proper LAPACK type.
`prepare_for_lapack` returns a character indicating the proper LAPACK data
type ('s', 'd', 'c', 'z') and a list of properly converted arrays.
Returns a list of properly converted arrays.
"""
# Make sure we have NumPy arrays
......@@ -1296,18 +1295,10 @@ def prepare_for_lapack(overwrite, *args):
# kind.
dtype = np.common_type(*[arr for arr, ovwrt in mats if arr is not None])
if dtype == np.float32:
lapacktype = 's'
elif dtype == np.float64:
lapacktype = 'd'
elif dtype == np.complex64:
lapacktype = 'c'
elif dtype == np.complex128:
lapacktype = 'z'
else:
if dtype not in (np.float32, np.float64, np.complex64, np.complex128):
raise AssertionError("Unexpected data type from common_type")
ret = [ lapacktype ]
ret = []
for npmat, ovwrt in mats:
# Now make sure that the array is contiguous, and copy if necessary.
if npmat is not None:
......@@ -1332,4 +1323,7 @@ def prepare_for_lapack(overwrite, *args):
ret.append(npmat)
return tuple(ret)
if len(ret) == 1:
return ret[0]
else:
return tuple(ret)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment