diff --git a/chainerrl/misc/__init__.py b/chainerrl/misc/__init__.py index 1219bee5e..c3c1266bb 100644 --- a/chainerrl/misc/__init__.py +++ b/chainerrl/misc/__init__.py @@ -4,5 +4,6 @@ from chainerrl.misc.draw_computational_graph import draw_computational_graph # NOQA from chainerrl.misc.draw_computational_graph import is_graphviz_available # NOQA from chainerrl.misc import env_modifiers # NOQA +from chainerrl.misc.namedpersistent import namedpersistent # NOQA from chainerrl.misc.is_return_code_zero import is_return_code_zero # NOQA from chainerrl.misc.random_seed import set_random_seed # NOQA diff --git a/chainerrl/misc/async_.py b/chainerrl/misc/async_.py index 317d20877..605e47bb0 100644 --- a/chainerrl/misc/async_.py +++ b/chainerrl/misc/async_.py @@ -12,6 +12,7 @@ import chainer import numpy as np +import chainerrl from chainerrl.misc import random_seed @@ -32,19 +33,56 @@ def ensure_initialized_update_rule(param): u.init_state(param) +def _set_persistent_values_recursively(link, persistent_name, shared_array): + if persistent_name.startswith('/'): + persistent_name = persistent_name[1:] + if hasattr(link, persistent_name): + attr_name = persistent_name + attr = getattr(link, attr_name) + if isinstance(attr, np.ndarray): + setattr(link, persistent_name, np.frombuffer( + shared_array, dtype=attr.dtype).reshape(attr.shape)) + else: + assert np.isscalar(attr) + # We wrap scalars with np.ndarray because + # multiprocessing.RawValue cannot be used as a scalar, while + # np.ndarray can be. + typecode = np.asarray(attr).dtype.char + setattr(link, attr_name, np.frombuffer( + shared_array, dtype=typecode).reshape(())) + else: + assert isinstance(link, (chainer.Chain, chainer.ChainList)) + assert '/' in persistent_name + child_name, remaining = persistent_name.split('/', 1) + if isinstance(link, chainer.Chain): + _set_persistent_values_recursively( + getattr(link, child_name), remaining, shared_array) + else: + _set_persistent_values_recursively( + link[int(child_name)], remaining, shared_array) + + def set_shared_params(a, b): - """Set shared params to a link. + """Set shared params (and persistent values) to a link. Args: a (chainer.Link): link whose params are to be replaced b (dict): dict that consists of (param_name, multiprocessing.Array) """ assert isinstance(a, chainer.Link) + remaining_keys = set(b.keys()) for param_name, param in a.namedparams(): if param_name in b: shared_param = b[param_name] param.array = np.frombuffer( shared_param, dtype=param.dtype).reshape(param.shape) + remaining_keys.remove(param_name) + for persistent_name, _ in chainerrl.misc.namedpersistent(a): + if persistent_name in b: + _set_persistent_values_recursively( + a, persistent_name, b[persistent_name]) + remaining_keys.remove(persistent_name) + assert not remaining_keys def make_params_not_shared(a): @@ -85,7 +123,22 @@ def extract_params_as_shared_arrays(link): assert isinstance(link, chainer.Link) shared_arrays = {} for param_name, param in link.namedparams(): - shared_arrays[param_name] = mp.RawArray('f', param.array.ravel()) + typecode = param.array.dtype.char + shared_arrays[param_name] = mp.RawArray(typecode, param.array.ravel()) + + for persistent_name, persistent in chainerrl.misc.namedpersistent(link): + if isinstance(persistent, np.ndarray): + typecode = persistent.dtype.char + shared_arrays[persistent_name] = mp.RawArray( + typecode, persistent.ravel()) + else: + assert np.isscalar(persistent) + # Wrap by a 1-dim array because multiprocessing.RawArray does not + # accept a 0-dim array. + persistent_as_array = np.asarray([persistent]) + typecode = persistent_as_array.dtype.char + shared_arrays[persistent_name] = mp.RawArray( + typecode, persistent_as_array) return shared_arrays diff --git a/chainerrl/misc/namedpersistent.py b/chainerrl/misc/namedpersistent.py new file mode 100644 index 000000000..1439a2d26 --- /dev/null +++ b/chainerrl/misc/namedpersistent.py @@ -0,0 +1,40 @@ +from __future__ import unicode_literals +from __future__ import print_function +from __future__ import division +from __future__ import absolute_import +from builtins import * # NOQA +from future import standard_library +standard_library.install_aliases() # NOQA + +import chainer + + +def _namedchildren(link): + if isinstance(link, chainer.Chain): + for name in sorted(link._children): + yield name, link.__dict__[name] + elif isinstance(link, chainer.ChainList): + for idx, child in enumerate(link._children): + yield str(idx), child + + +def namedpersistent(link): + """Return a generator of all (path, persistent) pairs for a given link. + + This function is adopted from https://github.com/chainer/chainer/pull/6788. + Once it is merged into Chainer, we should use the property instead. + + Args: + link (chainer.Link): Link. + + Returns: + A generator object that generates all (path, persistent) pairs. + The paths are relative from this link. + """ + d = link.__dict__ + for name in sorted(link._persistent): + yield '/' + name, d[name] + for name, child in _namedchildren(link): + prefix = '/' + name + for path, persistent in namedpersistent(child): + yield prefix + path, persistent diff --git a/tests/misc_tests/test_async.py b/tests/misc_tests/test_async.py index 6e52de85d..48fe8d644 100644 --- a/tests/misc_tests/test_async.py +++ b/tests/misc_tests/test_async.py @@ -19,15 +19,56 @@ import copy import numpy as np +import chainerrl from chainerrl.misc import async_ +def _assert_same_pointers_to_persistent_values(a, b): + assert isinstance(a, chainer.Link) + assert isinstance(b, chainer.Link) + a_persistents = dict(chainerrl.misc.namedpersistent(a)) + b_persistents = dict(chainerrl.misc.namedpersistent(b)) + assert set(a_persistents.keys()) == set(b_persistents.keys()) + for key in a_persistents: + a_persistent = a_persistents[key] + b_persistent = b_persistents[key] + assert isinstance(a_persistent, np.ndarray) + assert isinstance(b_persistent, np.ndarray) + assert a_persistent.ctypes.data == b_persistent.ctypes.data + + +def _assert_same_pointers_to_param_data(a, b): + assert isinstance(a, chainer.Link) + assert isinstance(b, chainer.Link) + a_params = dict(a.namedparams()) + b_params = dict(b.namedparams()) + assert set(a_params.keys()) == set(b_params.keys()) + for key in a_params.keys(): + assert isinstance(a_params[key], chainer.Variable) + assert isinstance(b_params[key], chainer.Variable) + assert (a_params[key].array.ctypes.data + == b_params[key].array.ctypes.data) + + +def _assert_different_pointers_to_param_grad(a, b): + assert isinstance(a, chainer.Link) + assert isinstance(b, chainer.Link) + a_params = dict(a.namedparams()) + b_params = dict(b.namedparams()) + assert set(a_params.keys()) == set(b_params.keys()) + for key in a_params.keys(): + assert isinstance(a_params[key], chainer.Variable) + assert isinstance(b_params[key], chainer.Variable) + assert (a_params[key].grad.ctypes.data + != b_params[key].grad.ctypes.data) + + class TestAsync(unittest.TestCase): def setUp(self): pass - def test_share_params(self): + def test_share_params_linear(self): # A's params are shared with B and C so that all the three share the # same parameter arrays @@ -35,6 +76,8 @@ def test_share_params(self): model_a = L.Linear(2, 2) arrays = async_.share_params_as_shared_arrays(model_a) + assert isinstance(arrays, dict) + assert set(arrays.keys()) == {'/W', '/b'} model_b = L.Linear(2, 2) model_c = L.Linear(2, 2) @@ -42,28 +85,94 @@ def test_share_params(self): async_.set_shared_params(model_b, arrays) async_.set_shared_params(model_c, arrays) - a_params = dict(model_a.namedparams()) - b_params = dict(model_b.namedparams()) - c_params = dict(model_c.namedparams()) + # Pointers to parameters must be the same + _assert_same_pointers_to_param_data(model_a, model_b) + _assert_same_pointers_to_param_data(model_a, model_c) + # Pointers to gradients must be different + _assert_different_pointers_to_param_grad(model_a, model_b) + _assert_different_pointers_to_param_grad(model_a, model_c) + _assert_different_pointers_to_param_grad(model_b, model_c) + # Pointers to persistent values must be the same + _assert_same_pointers_to_persistent_values(model_a, model_b) + _assert_same_pointers_to_persistent_values(model_a, model_c) + + def test_share_params_batch_normalization(self): + + # A's params and persistent values are all shared with B and C + + model_a = L.BatchNormalization(3) + + arrays = async_.share_params_as_shared_arrays(model_a) + assert isinstance(arrays, dict) + assert set(arrays.keys()) == { + '/gamma', '/beta', '/avg_mean', '/avg_var', '/N'} - def assert_same_pointers_to_data(a, b): - self.assertEqual(a['/W'].array.ctypes.data, - b['/W'].array.ctypes.data) - self.assertEqual(a['/b'].array.ctypes.data, - b['/b'].array.ctypes.data) + model_b = L.BatchNormalization(3) + model_c = L.BatchNormalization(3) - def assert_different_pointers_to_grad(a, b): - self.assertNotEqual(a['/W'].grad.ctypes.data, - b['/W'].grad.ctypes.data) - self.assertNotEqual(a['/b'].grad.ctypes.data, - b['/b'].grad.ctypes.data) + async_.set_shared_params(model_b, arrays) + async_.set_shared_params(model_c, arrays) + + # Pointers to parameters must be the same + _assert_same_pointers_to_param_data(model_a, model_b) + _assert_same_pointers_to_param_data(model_a, model_c) + # Pointers to gradients must be different + _assert_different_pointers_to_param_grad(model_a, model_b) + _assert_different_pointers_to_param_grad(model_a, model_c) + _assert_different_pointers_to_param_grad(model_b, model_c) + # Pointers to persistent values must be the same + _assert_same_pointers_to_persistent_values(model_a, model_b) + _assert_same_pointers_to_persistent_values(model_a, model_c) + + # Check if N is shared correctly among links + assert model_a.N == 0 + assert model_b.N == 0 + assert model_c.N == 0 + test_input = np.random.normal(size=(2, 3)).astype(np.float32) + model_a(test_input, finetune=True) + assert model_a.N == 1 + assert model_b.N == 1 + assert model_c.N == 1 + model_c(test_input, finetune=True) + assert model_a.N == 2 + assert model_b.N == 2 + assert model_c.N == 2 + + def test_share_params_chain_list(self): + + model_a = chainer.ChainList( + L.BatchNormalization(3), + chainer.ChainList(L.Linear(3, 5)), + ) + + arrays = async_.share_params_as_shared_arrays(model_a) + assert isinstance(arrays, dict) + assert set(arrays.keys()) == { + '/0/gamma', '/0/beta', '/0/avg_mean', '/0/avg_var', '/0/N', + '/1/0/W', '/1/0/b'} + + model_b = chainer.ChainList( + L.BatchNormalization(3), + chainer.ChainList(L.Linear(3, 5)), + ) + model_c = chainer.ChainList( + L.BatchNormalization(3), + chainer.ChainList(L.Linear(3, 5)), + ) + + async_.set_shared_params(model_b, arrays) + async_.set_shared_params(model_c, arrays) # Pointers to parameters must be the same - assert_same_pointers_to_data(a_params, b_params) - assert_same_pointers_to_data(a_params, c_params) + _assert_same_pointers_to_param_data(model_a, model_b) + _assert_same_pointers_to_param_data(model_a, model_c) # Pointers to gradients must be different - assert_different_pointers_to_grad(a_params, b_params) - assert_different_pointers_to_grad(a_params, c_params) + _assert_different_pointers_to_param_grad(model_a, model_b) + _assert_different_pointers_to_param_grad(model_a, model_c) + _assert_different_pointers_to_param_grad(model_b, model_c) + # Pointers to persistent values must be the same + _assert_same_pointers_to_persistent_values(model_a, model_b) + _assert_same_pointers_to_persistent_values(model_a, model_c) def test_share_states(self): @@ -114,10 +223,8 @@ def test_shared_link(self): model_a = chainer.ChainList(head.copy(), L.Linear(2, 3)) model_b = chainer.ChainList(head.copy(), L.Linear(2, 4)) - a_arrays = async_.extract_params_as_shared_arrays( - chainer.ChainList(model_a)) - b_arrays = async_.extract_params_as_shared_arrays( - chainer.ChainList(model_b)) + a_arrays = async_.extract_params_as_shared_arrays(model_a) + b_arrays = async_.extract_params_as_shared_arrays(model_b) print(('model_a shared_arrays', a_arrays)) print(('model_b shared_arrays', b_arrays)) diff --git a/tests/misc_tests/test_namedpersistent.py b/tests/misc_tests/test_namedpersistent.py new file mode 100644 index 000000000..b669a398e --- /dev/null +++ b/tests/misc_tests/test_namedpersistent.py @@ -0,0 +1,52 @@ +from __future__ import unicode_literals +from __future__ import print_function +from __future__ import division +from __future__ import absolute_import +from builtins import * # NOQA +from future import standard_library +standard_library.install_aliases() # NOQA + +import chainer +import numpy + +import chainerrl + + +def test_namedpersistent(): + # This test case is adopted from + # https://github.com/chainer/chainer/pull/6788 + + l1 = chainer.Link() + with l1.init_scope(): + l1.x = chainer.Parameter(shape=(2, 3)) + + l2 = chainer.Link() + with l2.init_scope(): + l2.x = chainer.Parameter(shape=2) + l2.add_persistent( + 'l2_a', numpy.array([1, 2, 3], dtype=numpy.float32)) + + l3 = chainer.Link() + with l3.init_scope(): + l3.x = chainer.Parameter() + l3.add_persistent( + 'l3_a', numpy.array([1, 2, 3], dtype=numpy.float32)) + + c1 = chainer.Chain() + with c1.init_scope(): + c1.l1 = l1 + c1.add_link('l2', l2) + c1.add_persistent( + 'c1_a', numpy.array([1, 2, 3], dtype=numpy.float32)) + + c2 = chainer.Chain() + with c2.init_scope(): + c2.c1 = c1 + c2.l3 = l3 + c2.add_persistent( + 'c2_a', numpy.array([1, 2, 3], dtype=numpy.float32)) + namedpersistent = list(chainerrl.misc.namedpersistent(c2)) + assert ( + [(name, id(p)) for name, p in namedpersistent] == + [('/c2_a', id(c2.c2_a)), ('/c1/c1_a', id(c2.c1.c1_a)), + ('/c1/l2/l2_a', id(c2.c1.l2.l2_a)), ('/l3/l3_a', id(c2.l3.l3_a))])