Commit 78d33dc2 authored by Artem Pulkin

ml_util: fix regularization convention and exclude bias

parent 9258db47
Pipeline #88133 failed with stages in 60 minutes and 55 seconds
import numericalunits as nu
import torch
from itertools import chain
from functools import partial
from collections import namedtuple
from math import isnan
@@ -599,7 +598,12 @@ def ln_regularization(order, weight, parameters):
    result : torch.tensor
        The resulting loss.
    """
-    return weight * sum(torch.linalg.vector_norm(i, ord=order) ** order for i in parameters)
+    accumulator = 0
+    n = 0
+    for i in parameters:
+        n += torch.numel(i)
+        accumulator += torch.linalg.vector_norm(i, ord=order) ** order
+    return accumulator * weight / n
l1_regularization = partial(ln_regularization, 1)
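
For reference, the new convention divides the summed penalty by the total number of parameter entries, i.e. R = (weight / N) * sum_i ||p_i||_n ** n, so the regularization strength no longer grows with model size. A minimal before/after sketch (the tensor values and the 0.1 weight are illustrative only, not taken from the repository):

import torch

# Two illustrative parameter tensors with 4 + 2 = 6 entries in total.
w = torch.tensor([1.0, -2.0, 3.0, 0.0])
b = torch.tensor([0.5, -0.5])

# Old convention: weight * sum of norms, which scales with model size.
old = 0.1 * sum(torch.linalg.vector_norm(p, ord=1) for p in (w, b))

# New convention: the same sum averaged over the parameter count.
new = 0.1 * sum(torch.linalg.vector_norm(p, ord=1) for p in (w, b)) / 6

print(old.item())  # 0.1 * (6 + 1) = 0.7
print(new.item())  # 0.7 / 6 ≈ 0.117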
@@ -697,20 +701,35 @@ class SimpleClosure:
        self.optimizer = optimizer(self.learning_parameters(), **opt_args)
        return self.optimizer
-    def learning_parameters(self):
+    def learning_parameters(self, exclude=None):
        """
        Learning parameters.

+        Parameters
+        ----------
+        exclude : {tuple, list, str, None}
+            A list of parameter names to exclude.
+
        Returns
        -------
        params : Iterable
            Parameters to learn.
        """
-        return chain(*tuple(i.parameters() for i in self.networks))
+        if exclude is None:
+            exclude = ()
+        if isinstance(exclude, str):
+            exclude = exclude,
+        for ntw in self.networks:
+            for name, p in ntw.named_parameters():
+                for test in exclude:
+                    if test in name:
+                        break
+                else:
+                    yield p
    def loss(self, dataset=None, save=True):
        """
-        The loss function (excluding regulariation).
+        The loss function (excluding regularization).

        Parameters
        ----------
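
The exclusion test above matches substrings of torch's parameter names ("0.weight", "0.bias", ...), which is why a single string such as "bias" is enough. A minimal sketch of the same mechanics on a stand-in module (the helper name filtered_parameters is hypothetical, not part of ml_util):

import torch

# Stand-in network; torch names its parameters "0.weight", "0.bias", etc.
net = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 1))

def filtered_parameters(module, exclude=()):
    # Yield parameters whose names contain none of the excluded substrings,
    # mirroring the for/else loop in learning_parameters above.
    for name, p in module.named_parameters():
        if not any(test in name for test in exclude):
            yield p

print(len(list(filtered_parameters(net))))                     # 4: 2 weights + 2 biases
print(len(list(filtered_parameters(net, exclude=("bias",)))))  # 2: weights only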
@@ -733,7 +752,7 @@ class SimpleClosure:
        self.last_loss = loss
        return loss

-    def propagate(self, dataset=None):
+    def propagate(self, dataset=None, reg_exclude="bias"):
        """
        Propagates the closure.
@@ -741,6 +760,8 @@
        ----------
        dataset : torch.utils.data.Dataset
            The dataset to compute loss function of.
+        reg_exclude : str
+            Parameter names to exclude from regularization.
        Returns
        -------
@@ -753,7 +774,7 @@
        if isnan(loss.item()):
            raise RuntimeError(f"Optimizer is not stable with loss={loss}")
        if self.regularization is not None:
-            loss = loss + self.regularization(self.learning_parameters())
+            loss = loss + self.regularization(self.learning_parameters(exclude=reg_exclude))
        loss.backward()
        return loss
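
Taken together, propagate() now optimizes the data loss plus a size-normalized penalty over weights only (biases excluded by default). A self-contained sketch of that objective, assuming a plain linear model and an arbitrary 1e-3 penalty weight in place of the package's classes:

import torch

net = torch.nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)

# Data term.
data_loss = torch.nn.functional.mse_loss(net(x), y)

# Penalty term: biases skipped, norms averaged over the entry count.
params = [p for name, p in net.named_parameters() if "bias" not in name]
n = sum(torch.numel(p) for p in params)
reg = 1e-3 * sum(torch.linalg.vector_norm(p, ord=1) for p in params) / n

total = data_loss + reg
total.backward()  # gradients flow through both terms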
@@ -635,8 +635,7 @@ class MainTest(TestCase):
        for i in potentials:
            self.assertIs(i.family, ml_potential_family)
            for ll in i.parameters["nn"][::2]:
-                assert ll.weight.data.max() < 1e-8
-                assert ll.bias.data.max() < 1e-8
+                assert ll.weight.data.max() < 1e-5
    def test_version(self):
        self.assertIsInstance(__version__, str)
@@ -462,3 +462,11 @@ class WorkflowTest(TestCase):
            testing.assert_equal(i.vectors, j.vectors, err_msg=str(x))
            testing.assert_allclose(i.coordinates, j.coordinates, err_msg=str(x))
            testing.assert_equal(i.values, j.values, err_msg=str(x))
+
+
+def test_parameter_filtering():
+    nn = ml_util.behler_nn(10, n_layers=2)
+    closure = ml_util.SimpleClosure([nn], ml_util.energy_loss)
+    assert len(list(closure.learning_parameters())) == 4
+    assert len(list(closure.learning_parameters(exclude="bias"))) == 2
+    assert len(list(closure.learning_parameters(exclude="weight"))) == 2