diff --git a/setup.py b/setup.py index e1e3ddd..148a6d6 100644 --- a/setup.py +++ b/setup.py @@ -2,15 +2,17 @@ from setuptools import find_namespace_packages from setuptools import setup +import re def _get_sonnet_version(): with open('sonnet/__init__.py') as fp: for line in fp: if line.startswith('__version__'): - g = {} - exec(line, g) # pylint: disable=exec-used - return g['__version__'] + match = re.search(r"__version__\s*=\s*['\"]([^'\"]+)['\"]", line) + if match: + return match.group(1) + raise ValueError('Could not parse __version__ from line') raise ValueError('`__version__` not defined in `sonnet/__init__.py`') diff --git a/sonnet/src/nets/dnc/read.py b/sonnet/src/nets/dnc/read.py index 7c1bde6..bdd26a3 100644 --- a/sonnet/src/nets/dnc/read.py +++ b/sonnet/src/nets/dnc/read.py @@ -36,7 +36,7 @@ def read(memory, """ with tf.name_scope("read_memory"): if squash_before_access: - squash_op(weights) + memory = squash_op(memory) read_word = tf.matmul(weights, memory) if squash_after_access: read_word = squash_op(read_word)