Commit 9258db47 authored by Artem Pulkin's avatar Artem Pulkin
Browse files

ml: fix deprecation

parent 864fb371
Pipeline #87275 failed with stages
in 14 minutes and 9 seconds
...@@ -1504,11 +1504,10 @@ class Normalization: ...@@ -1504,11 +1504,10 @@ class Normalization:
if pad: if pad:
atom_counts_padded = torch.cat((atom_counts, torch.eye(n, dtype=atom_counts.dtype))) atom_counts_padded = torch.cat((atom_counts, torch.eye(n, dtype=atom_counts.dtype)))
energy_padded = torch.cat((energy, torch.zeros((n, 1), dtype=energy.dtype))) energy_padded = torch.cat((energy, torch.zeros((n, 1), dtype=energy.dtype)))
energy_offsets, _ = torch.lstsq(energy_padded, atom_counts_padded) energy_offsets = torch.linalg.lstsq(atom_counts_padded, energy_padded).solution
else: else:
energy_offsets, _ = torch.lstsq(energy, atom_counts) energy_offsets = torch.linalg.lstsq(atom_counts, energy).solution
energy_offsets = energy_offsets[:n]
residuals = energy - atom_counts @ energy_offsets residuals = energy - atom_counts @ energy_offsets
return energy_offsets, residuals return energy_offsets, residuals
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment