diff --git a/TESTS/2DimensionOutput.py b/TESTS/2DimensionOutput.py new file mode 100644 index 0000000000..3ade99220d --- /dev/null +++ b/TESTS/2DimensionOutput.py @@ -0,0 +1,29 @@ +import tensorflow as tf +from tensorflow.python.framework import function + +fac = function.Declare("Fac", [("n", tf.int32)], [("ret", tf.int32)]) + +@function.Defun(tf.int32, func_name="Fac", out_names=["ret"]) +def FacImpl(n): + return tf.cond(tf.less_equal(n, 1), + lambda: tf.constant([1,1]), + lambda: [n,n]*fac(n-1)) + + +FacImpl.add_to_graph(tf.get_default_graph()) + +n = tf.placeholder(tf.int32, shape=[]) +x = tf.add(n, 1) +result = fac(x) +y = tf.add(result, [1,1]) + +#print(tf.get_default_graph().as_graph_def()) + +writer = tf.summary.FileWriter('./graphs', tf.get_default_graph()) + +sess = tf.Session() +print(sess.run(y, feed_dict={n: 5})) + +writer.close() + +sess.close() diff --git a/TESTS/ackermann.py b/TESTS/ackermann.py new file mode 100644 index 0000000000..680e1a618a --- /dev/null +++ b/TESTS/ackermann.py @@ -0,0 +1,44 @@ +import tensorflow as tf +from tensorflow.python.framework import function + +ack = function.Declare("Ack", [("m", tf.int32), ("n", tf.int32)], [("ret", tf.int32)]) + +@function.Defun(tf.int32, tf.int32, func_name="Ack", out_names=["ret"]) +def AckImpl(m,n): + + def f1(): + ret = n + 1 + return ret + + def f2(): + def ff1(): + r = ack(m-1,1) + return r + + def ff2(): + r = ack(m-1, ack(m, n-1)) + return r + + ret = tf.cond(tf.equal(n, 0), ff1, ff2) + return ret + + return tf.cond(tf.equal(m, 0), f1, f2) + + +AckImpl.add_to_graph(tf.get_default_graph()) + +n = tf.placeholder(tf.int32, shape=[]) +m = tf.placeholder(tf.int32, shape=[]) +res = ack(m,n) + +writer = tf.summary.FileWriter('./graphs', tf.get_default_graph()) + +sess = tf.Session() + +#print(tf.get_default_graph().as_graph_def()) + +print(sess.run(res, feed_dict={m:2, n:3})) + +sess.close() + +writer.close() diff --git a/TESTS/create_worker.py b/TESTS/create_worker.py new file mode 100644 index 0000000000..62e80ade27 --- /dev/null +++ b/TESTS/create_worker.py @@ -0,0 +1,13 @@ +# Get task number from command line +import sys +task_number = int(sys.argv[1]) + +import tensorflow as tf + +cluster = tf.train.ClusterSpec({"local": ["localhost:2222", "localhost:2223"]}) +server = tf.train.Server(cluster, job_name="local", task_index=task_number) + +print("Starting server #{}".format(task_number)) + +server.start() +server.join() diff --git a/TESTS/distributed/distr_factorial.py b/TESTS/distributed/distr_factorial.py new file mode 100644 index 0000000000..dfbf931b20 --- /dev/null +++ b/TESTS/distributed/distr_factorial.py @@ -0,0 +1,40 @@ +import tensorflow as tf +from tensorflow.python.framework import function + +cluster = tf.train.ClusterSpec({"local": ["localhost:2222", "localhost:2223"]}) + +fac = function.Declare("Fac", [("n", tf.int32)], [("ret", tf.int32)]) + +@function.Defun(tf.int32, func_name="Fac", out_names=["ret"]) +def FacImpl(n): + + def f1(): + with tf.device("/job:local/replica:0/task:0/device:CPU:0"): + ret = tf.constant(1) + return ret + def f2(): + with tf.device("/job:local/replica:0/task:1/device:CPU:0"): + ret = n * fac(n - 1) + return ret + + with tf.device("/job:local/replica:0/task:1/device:CPU:0"): + pred = tf.less_equal(n, 1) + + return tf.cond(pred, f1, f2) + +FacImpl.add_to_graph(tf.get_default_graph()) + +n = tf.placeholder(tf.int32, shape=[]) +x = tf.add(n, 1) +result = fac(x) +y = tf.add(result, 1) + +#print(tf.get_default_graph().as_graph_def()) + +writer = tf.summary.FileWriter('./graphs', 
tf.get_default_graph()) + +with tf.Session("grpc://localhost:2222") as sess: + print(sess.run(y, feed_dict={n: 5})) + +writer.close() + diff --git a/TESTS/distributed/distr_fcallsg.py b/TESTS/distributed/distr_fcallsg.py new file mode 100644 index 0000000000..da241b6965 --- /dev/null +++ b/TESTS/distributed/distr_fcallsg.py @@ -0,0 +1,44 @@ +import tensorflow as tf +from tensorflow.python.framework import function + +cluster = tf.train.ClusterSpec({"local": ["localhost:2222", "localhost:2223"]}) + +@function.Defun(tf.float32) +def G(x): + + with tf.device("/job:local/replica:0/task:1/device:CPU:0"): + ret = x + x + + return ret + + +@function.Defun(tf.float32, tf.float32) +def MyFunc(x, y): + + with tf.device("/job:local/replica:0/task:0/device:CPU:0"): + g1 = G(x) + g2 = G(y) + + ret = g1 + g2 + + return ret + + +# Building the graph. + +a = tf.constant([4.0], name="a") +b = tf.placeholder(tf.float32, name="MyPlaceHolder") + +add = tf.add(a, b, name="add") +sub = tf.subtract(a, b, name="sub") + +ret = MyFunc(add, sub, name='mycall') + +#x = tf.add(c, d) + +writer = tf.summary.FileWriter('./graphs', tf.get_default_graph()) + +with tf.Session("grpc://localhost:2222") as sess: + print(sess.run([ret], feed_dict={b:1})) + +writer.close() diff --git a/TESTS/distributed/distr_fibonacci.py b/TESTS/distributed/distr_fibonacci.py new file mode 100644 index 0000000000..e8c3e59f88 --- /dev/null +++ b/TESTS/distributed/distr_fibonacci.py @@ -0,0 +1,39 @@ +import tensorflow as tf +from tensorflow.python.framework import function + +cluster = tf.train.ClusterSpec({"local": ["localhost:2222", "localhost:2223"]}) + +fib = function.Declare("Fib", [("n", tf.int32)], [("ret", tf.int32)]) + +@function.Defun(tf.int32, func_name="Fib", out_names=["ret"]) +def FibImpl(n): + + def f1(): + with tf.device("/job:local/replica:0/task:0/device:CPU:0"): + ret = tf.constant(1) + return ret + def f2(): + with tf.device("/job:local/replica:0/task:0/device:CPU:0"): + fib1 = fib(n-1) + with tf.device("/job:local/replica:0/task:1/device:CPU:0"): + fib2 = fib(n-2) + + return fib1 + fib2 + + return tf.cond(tf.less_equal(n, 1), f1, f2) + +FibImpl.add_to_graph(tf.get_default_graph()) + +n = tf.placeholder(tf.int32, shape=[]) +x = fib(n) + +res = tf.add(x, 1) + +#print(tf.get_default_graph().as_graph_def()) + +writer = tf.summary.FileWriter('./graphs', tf.get_default_graph()) + +with tf.Session("grpc://localhost:2222") as sess: + print(sess.run(res, feed_dict={n: 20})) + +writer.close() diff --git a/TESTS/distributed/distr_fog.py b/TESTS/distributed/distr_fog.py new file mode 100644 index 0000000000..430665a7de --- /dev/null +++ b/TESTS/distributed/distr_fog.py @@ -0,0 +1,36 @@ +import tensorflow as tf +from tensorflow.python.framework import function + +cluster = tf.train.ClusterSpec({"local": ["localhost:2222", "localhost:2223"]}) + +@function.Defun(tf.float32) +def G(x): + with tf.device("/job:local/replica:0/task:0/device:CPU:0"): + add = x + 1 + with tf.device("/job:local/replica:0/task:1/device:CPU:0"): + ret = x * add + return ret + +@function.Defun(tf.float32) +def F(x): + with tf.device("/job:local/replica:0/task:1/device:CPU:0"): + add = x + 1 + with tf.device("/job:local/replica:0/task:0/device:CPU:0"): + ret = x * add + return ret + + +a = tf.constant([4.0], name="a") +b = tf.placeholder(tf.float32, name="MyPlaceHolder") + +add = tf.add(a, b, name="add") + +ret = F(G(add), name='mycall') + +writer = tf.summary.FileWriter('./graphs', tf.get_default_graph()) + +with tf.Session("grpc://localhost:2222") as sess: + 
print(sess.run([ret], feed_dict={b:1})) + +writer.close() + diff --git a/TESTS/distributed/distr_funcSimple.py b/TESTS/distributed/distr_funcSimple.py new file mode 100644 index 0000000000..1fbe935696 --- /dev/null +++ b/TESTS/distributed/distr_funcSimple.py @@ -0,0 +1,36 @@ +import tensorflow as tf +from tensorflow.python.framework import function + +cluster = tf.train.ClusterSpec({"local": ["localhost:2222", "localhost:2223"]}) + +@function.Defun(tf.int32, tf.int32) +def MyFunc(x, y): + + with tf.device("/job:local/replica:0/task:1/device:CPU:0"): + add1 = x + y + + return [add1, x - y] + + +# Building the graph. + +a = tf.constant([4], name="x") +b = tf.placeholder(tf.int32, name="MyPlaceHolder") + +with tf.device("/job:local/replica:0/task:0/device:CPU:0"): + add = tf.add(a, b, name="add") + +with tf.device("/job:local/replica:0/task:1/device:CPU:0"): + sub = tf.subtract(a, b, name="sub") + +[c,d] = MyFunc(add, sub, name='mycall') + +x = tf.add(c, d) + +#print(tf.get_default_graph().as_graph_def()) + +writer = tf.summary.FileWriter('./graphs', tf.get_default_graph()) + +with tf.Session("grpc://localhost:2222") as sess: + print(sess.run([x], feed_dict={b:1})) +writer.close() diff --git a/TESTS/distributed/distr_mutrec.py b/TESTS/distributed/distr_mutrec.py new file mode 100644 index 0000000000..864809e57e --- /dev/null +++ b/TESTS/distributed/distr_mutrec.py @@ -0,0 +1,49 @@ +import tensorflow as tf +from tensorflow.python.framework import function + +cluster = tf.train.ClusterSpec({"local": ["localhost:2222", "localhost:2223"]}) + +f = function.Declare("F", [("n", tf.int32)], [("ret", tf.int32)]) +g = function.Declare("G", [("n", tf.int32)], [("ret", tf.int32)]) + +@function.Defun(tf.int32, func_name="F", out_names=["ret"]) +def FImpl(n): + + def f1(): + with tf.device("/job:local/replica:0/task:0/device:CPU:0"): + ret = tf.constant(1) + return ret + def f2(): + with tf.device("/job:local/replica:0/task:0/device:CPU:0"): + x = n - 1 + ret = g(x) + return ret + +# with tf.device("/job:local/replica:0/task:1/device:CPU:0"): + pred = tf.less_equal(n, 1) + + return tf.cond(pred, f1, f2) + + +@function.Defun(tf.int32, func_name="G", out_names=["ret"]) +def GImpl(n): + + with tf.device("/job:local/replica:0/task:1/device:CPU:0"): + x = n - 1 + ret = f(x) + return ret + + +FImpl.add_to_graph(tf.get_default_graph()) +GImpl.add_to_graph(tf.get_default_graph()) + + +n = tf.placeholder(tf.int32, name="MyPlaceHolder") +x = f(n) + +writer = tf.summary.FileWriter('./graphs', tf.get_default_graph()) + +with tf.Session("grpc://localhost:2222") as sess: + print(sess.run([x], feed_dict={n:4})) + +writer.close() diff --git a/TESTS/factorial.py b/TESTS/factorial.py new file mode 100644 index 0000000000..25542860f9 --- /dev/null +++ b/TESTS/factorial.py @@ -0,0 +1,29 @@ +import tensorflow as tf +from tensorflow.python.framework import function + +fac = function.Declare("Fac", [("n", tf.int32)], [("ret", tf.int32)]) + +@function.Defun(tf.int32, func_name="Fac", out_names=["ret"]) +def FacImpl(n): + return tf.cond(tf.less_equal(n, 1), + lambda: tf.constant(1), + lambda: n * fac(n - 1)) + + +FacImpl.add_to_graph(tf.get_default_graph()) + +n = tf.placeholder(tf.int32, shape=[]) +x = tf.add(n, 1) +result = fac(x) +y = tf.add(result, 1) + +#print(tf.get_default_graph().as_graph_def()) + +writer = tf.summary.FileWriter('./graphs', tf.get_default_graph()) + +sess = tf.Session() +print(sess.run(y, feed_dict={n: 5})) + +writer.close() + +sess.close() diff --git a/TESTS/fcallsg.py b/TESTS/fcallsg.py new file mode 
100644 index 0000000000..9d74584809 --- /dev/null +++ b/TESTS/fcallsg.py @@ -0,0 +1,33 @@ +import tensorflow as tf +from tensorflow.python.framework import function + +@function.Defun(tf.float32) +def G(x): + return [x + x] + + +@function.Defun(tf.float32, tf.float32) +def MyFunc(x, y): + return [G(x), G(y)] + + +# Building the graph. + +a = tf.constant([4.0], name="a") +b = tf.placeholder(tf.float32, name="MyPlaceHolder") + +add = tf.add(a, b, name="add") +sub = tf.subtract(a, b, name="sub") + +[c,d] = MyFunc(add, sub, name='mycall') + +x = tf.add(c, d) + +writer = tf.summary.FileWriter('./graphs', tf.get_default_graph()) + +with tf.Session() as sess: # no need to manually close the session +# print(sess.run([add, sub], feed_dict={b:1})) + print(sess.run([x], feed_dict={b:1})) + #writer = tf.summary.FileWriter('./graphs', tf.get_default_graph()) + +writer.close() diff --git a/TESTS/ff.py b/TESTS/ff.py new file mode 100644 index 0000000000..2915db693e --- /dev/null +++ b/TESTS/ff.py @@ -0,0 +1,27 @@ +import tensorflow as tf +from tensorflow.python.framework import function + +fac = function.Declare("Fac", [("n", tf.int32)], [("ret", tf.int32)]) + +@function.Defun(tf.int32, func_name="Fac", out_names=["ret"]) +def FacImpl(n): + return fac(n) + + +FacImpl.add_to_graph(tf.get_default_graph()) + +n = tf.placeholder(tf.int32, shape=[]) +x = tf.add(n, 1) +result = fac(x) +#y = tf.add(result, 1) + +#print(tf.get_default_graph().as_graph_def()) + +writer = tf.summary.FileWriter('./graphs', tf.get_default_graph()) + +sess = tf.Session() +print(sess.run(result, feed_dict={n: 5})) + +writer.close() + +sess.close() diff --git a/TESTS/fibonacci.py b/TESTS/fibonacci.py new file mode 100644 index 0000000000..680f8be425 --- /dev/null +++ b/TESTS/fibonacci.py @@ -0,0 +1,29 @@ +import tensorflow as tf +from tensorflow.python.framework import function + +fib = function.Declare("Fib", [("n", tf.int32)], [("ret", tf.int32)]) + +@function.Defun(tf.int32, func_name="Fib", out_names=["ret"]) +def FibImpl(n): + return tf.cond(tf.less_equal(n, 1), + lambda: tf.constant(1), + lambda: fib(n-1) + fib(n-2)) + +FibImpl.add_to_graph(tf.get_default_graph()) + +n = tf.placeholder(tf.int32, shape=[]) +x = fib(n) + +res = tf.add(x, 1) + +writer = tf.summary.FileWriter('./graphs', tf.get_default_graph()) + +sess = tf.Session() + +#print(tf.get_default_graph().as_graph_def()) + + +writer.close() +print(sess.run(res, feed_dict={n: 24})) + +sess.close() diff --git a/TESTS/fog.py b/TESTS/fog.py new file mode 100644 index 0000000000..f6a21e7a8f --- /dev/null +++ b/TESTS/fog.py @@ -0,0 +1,26 @@ +import tensorflow as tf +from tensorflow.python.framework import function + +@function.Defun(tf.float32) +def G(x): + return x * x + + +@function.Defun(tf.float32) +def F(x): + return x + x + + +a = tf.constant([4.0], name="a") +b = tf.placeholder(tf.float32, name="MyPlaceHolder") + +add = tf.add(a, b, name="add") + +ret = F(G(add), name='mycall') + +writer = tf.summary.FileWriter('./graphs', tf.get_default_graph()) + +with tf.Session() as sess: + print(sess.run([ret], feed_dict={b:1})) + +writer.close() diff --git a/TESTS/funcSimple.py b/TESTS/funcSimple.py new file mode 100644 index 0000000000..7466491f33 --- /dev/null +++ b/TESTS/funcSimple.py @@ -0,0 +1,30 @@ +import tensorflow as tf +from tensorflow.python.framework import function + +@function.Defun(tf.float32, tf.float32) +def MyFunc(x, y): + return [x + y, x - y] + + +# Building the graph. 
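+# Calling the Defun-decorated MyFunc below emits a single function-call node into the graph; the name='mycall' kwarg names that call node.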
+ +a = tf.constant([4.0], name="a") +b = tf.placeholder(tf.float32, name="MyPlaceHolder") + +add = tf.add(a, b, name="add") +sub = tf.subtract(a, b, name="sub") + +[c,d] = MyFunc(add, sub, name='mycall') + +x = tf.add(c, d) + +#print(tf.get_default_graph().as_graph_def()) + +writer = tf.summary.FileWriter('./graphs', tf.get_default_graph()) + +with tf.Session() as sess: # no need to manually close the session +# print(sess.run([add, sub], feed_dict={b:1})) + print(sess.run([x], feed_dict={b:1})) + #writer = tf.summary.FileWriter('./graphs', tf.get_default_graph()) + +writer.close() diff --git a/TESTS/hello.py b/TESTS/hello.py new file mode 100644 index 0000000000..c804b5f983 --- /dev/null +++ b/TESTS/hello.py @@ -0,0 +1,5 @@ +# Python +import tensorflow as tf +hello = tf.constant('Hello, TensorFlow!') +sess = tf.Session() +print(sess.run(hello)) diff --git a/TESTS/mutrec.py b/TESTS/mutrec.py new file mode 100644 index 0000000000..370f3793a5 --- /dev/null +++ b/TESTS/mutrec.py @@ -0,0 +1,31 @@ +import tensorflow as tf +from tensorflow.python.framework import function + +f = function.Declare("F", [("n", tf.int32)], [("ret", tf.int32)]) +g = function.Declare("G", [("n", tf.int32)], [("ret", tf.int32)]) + +@function.Defun(tf.int32, func_name="F", out_names=["ret"]) +def FImpl(n): + return tf.cond(tf.less_equal(n, 1), + lambda: tf.constant(1), + lambda: g(n - 1)) + +@function.Defun(tf.int32, func_name="G", out_names=["ret"]) +def GImpl(n): + return f(n) + +# Building the graph. + +FImpl.add_to_graph(tf.get_default_graph()) +GImpl.add_to_graph(tf.get_default_graph()) + + +n = tf.placeholder(tf.int32, name="MyPlaceHolder") +x = f(n) + +writer = tf.summary.FileWriter('./graphs', tf.get_default_graph()) + +with tf.Session() as sess: # no need to manually close the session + print(sess.run([x], feed_dict={n:4})) + +writer.close() diff --git a/TESTS/not_lazy.py b/TESTS/not_lazy.py new file mode 100644 index 0000000000..b694c1584e --- /dev/null +++ b/TESTS/not_lazy.py @@ -0,0 +1,29 @@ +import tensorflow as tf +from tensorflow.python.framework import function + +fac = function.Declare("Fac", [("x", tf.int32), ("y", tf.int32)], [("ret", tf.int32)]) + +@function.Defun(tf.int32, tf.int32, func_name="Fac", out_names=["ret"]) +def FacImpl(x, y): + return tf.cond(tf.less_equal(x, 1), + lambda: tf.constant(1), + lambda: fac(x-1, fac(x,y))) + +FacImpl.add_to_graph(tf.get_default_graph()) + +x = tf.placeholder(tf.int32, shape=[]) +result = fac(x, 2) + + +y = tf.add(result, 1) + +#print(tf.get_default_graph().as_graph_def()) + +writer = tf.summary.FileWriter('./graphs', tf.get_default_graph()) + +sess = tf.Session() +print(sess.run(y, feed_dict={x:2})) + +writer.close() + +sess.close() diff --git a/TESTS/primes.py b/TESTS/primes.py new file mode 100644 index 0000000000..7668c983fa --- /dev/null +++ b/TESTS/primes.py @@ -0,0 +1,62 @@ +import tensorflow as tf +from tensorflow.python.framework import function + +primes = function.Declare("primes", [("x", tf.int32)], [("ret", tf.int32)]) +findPrimePlus = function.Declare("findPrimePlus", [("n", tf.int32),("i", tf.int32)], [("ret", tf.int32)]) +findPrimeMinus = function.Declare("findPrimeMinus", [("n", tf.int32),("i", tf.int32)], [("ret", tf.int32)]) +testPrime = function.Declare("testPrime", [("n", tf.int32),("i", tf.int32)], [("ret", tf.bool)]) + + +@function.Defun(tf.int32, func_name="primes", out_names=["ret"]) +def PrimesImpl(n): + return tf.cond(tf.less_equal(n, 0), + lambda: 2, + lambda: tf.cond(tf.equal(n, 1), + lambda: 3, + lambda: findPrimeMinus(n-2,1) + )) 
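+# The helpers below walk prime candidates of the form 6*i-1 / 6*i+1 (every prime greater than 3 has this form); testPrime(n, i) trial-divides n by divisors of the same form until their square exceeds n.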
+PrimesImpl.add_to_graph(tf.get_default_graph()) + +@function.Defun(tf.int32, tf.int32, func_name="findPrimeMinus", out_names=["ret"]) +def FindPrimeMinusImpl(n,i): + return tf.cond(testPrime(6*i-1, 1), + lambda: tf.cond(tf.equal(n, 0), + lambda: 6*i-1, + lambda: findPrimePlus(n-1,i)), + lambda: findPrimePlus(n,i)) +FindPrimeMinusImpl.add_to_graph(tf.get_default_graph()) + +@function.Defun(tf.int32, tf.int32, func_name="findPrimePlus", out_names=["ret"]) +def FindPrimePlusImpl(n,i): + return tf.cond(testPrime(6*i+1, 1), + lambda: tf.cond(tf.equal(n, 0), + lambda: 6*i+1, + lambda: findPrimeMinus(n-1,i+1)), + lambda: findPrimeMinus(n,i+1)) +FindPrimePlusImpl.add_to_graph(tf.get_default_graph()) + + +@function.Defun(tf.int32, tf.int32, func_name="testPrime", out_names=["ret"]) +def TestPrimeImpl(n,i): + return tf.cond(tf.greater((6*i-1)*(6*i-1), n), + lambda: True, + lambda: tf.cond(tf.equal(tf.mod(n, (6*i-1)), 0), + lambda: False, + lambda: tf.cond(tf.equal(tf.mod(n, (6*i+1)), 0), + lambda: False, + lambda: testPrime(n, i+1)))) +TestPrimeImpl.add_to_graph(tf.get_default_graph()) + +n = tf.placeholder(tf.int32, shape=[]) +res = primes(n) + +writer = tf.summary.FileWriter('./graphs', tf.get_default_graph()) + +sess = tf.Session() + +#print(tf.get_default_graph().as_graph_def()) + +writer.close() +print(sess.run(res, feed_dict={n:7500})) + +sess.close() diff --git a/TESTS/takeuchi.py b/TESTS/takeuchi.py new file mode 100644 index 0000000000..a90e1b78a6 --- /dev/null +++ b/TESTS/takeuchi.py @@ -0,0 +1,28 @@ +import tensorflow as tf +from tensorflow.python.framework import function + +tak = function.Declare("Tak", [("x", tf.int32), ("y", tf.int32), ("z", tf.int32)], [("ret", tf.int32)]) + +@function.Defun(tf.int32, tf.int32, tf.int32, func_name="Tak", out_names=["ret"]) +def TakImpl(x,y,z): + return tf.cond(tf.less(y, x), + lambda: tak(tak(x-1,y,z), tak(y-1,z,x), tak(z-1,x,y)), + lambda: z) + +TakImpl.add_to_graph(tf.get_default_graph()) + +x = tf.placeholder(tf.int32, shape=[]) +y = tf.placeholder(tf.int32, shape=[]) +z = tf.placeholder(tf.int32, shape=[]) +res = tak(x,y,z) + +writer = tf.summary.FileWriter('./graphs', tf.get_default_graph()) + +sess = tf.Session() + +#print(tf.get_default_graph().as_graph_def()) + +writer.close() +print(sess.run(res, feed_dict={x:24, y:16, z:8})) + +sess.close() diff --git a/TESTS/while.py b/TESTS/while.py new file mode 100644 index 0000000000..8ca58ad6b4 --- /dev/null +++ b/TESTS/while.py @@ -0,0 +1,14 @@ +import tensorflow as tf + +n = tf.constant(4) +res = tf.while_loop(lambda i, n: i > 0, lambda i, n: (i-1, n*2), [4, 1]) + + +writer = tf.summary.FileWriter('./graphs', tf.get_default_graph()) + +sess = tf.Session() +result = sess.run([res]) +print(result) + +writer.close() +sess.close() diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index b46b572999..c2f3da1242 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -567,6 +567,7 @@ tf_gen_op_libs( "data_flow_ops", "dataset_ops", "function_ops", + "function_control_ops", "functional_ops", "image_ops", "io_ops", @@ -648,6 +649,7 @@ cc_library( ":dataset_ops_op_lib", ":function_ops_op_lib", ":functional_ops_op_lib", + ":function_control_ops_op_lib", ":image_ops_op_lib", ":io_ops_op_lib", ":linalg_ops_op_lib", @@ -780,6 +782,7 @@ cc_library( "//tensorflow/core/kernels:dataset_ops", "//tensorflow/core/kernels:fake_quant_ops", "//tensorflow/core/kernels:function_ops", + "//tensorflow/core/kernels:function_control_ops", "//tensorflow/core/kernels:image", "//tensorflow/core/kernels:io",
"//tensorflow/core/kernels:linalg", @@ -2937,11 +2940,13 @@ tf_cc_test( ":testlib", "//tensorflow/cc:cc_ops", "//tensorflow/core/kernels:control_flow_ops", + "//tensorflow/core/kernels:function_control_ops", "//tensorflow/core/kernels:cwise_op", "//tensorflow/core/kernels:dense_update_ops", "//tensorflow/core/kernels:fifo_queue_op", "//tensorflow/core/kernels:function_ops", "//tensorflow/core/kernels:identity_op", + "//tensorflow/core/kernels:identity_n_op", "//tensorflow/core/kernels:matmul_op", "//tensorflow/core/kernels:ops_util", "//tensorflow/core/kernels:queue_ops", @@ -2978,11 +2983,13 @@ tf_cc_test( # Link with support for TensorFlow Debugger (tfdbg). "//tensorflow/core/debug", "//tensorflow/core/kernels:control_flow_ops", + "//tensorflow/core/kernels:function_control_ops", "//tensorflow/core/kernels:cwise_op", "//tensorflow/core/kernels:dense_update_ops", "//tensorflow/core/kernels:fifo_queue_op", "//tensorflow/core/kernels:function_ops", "//tensorflow/core/kernels:identity_op", + "//tensorflow/core/kernels:identity_n_op", "//tensorflow/core/kernels:matmul_op", "//tensorflow/core/kernels:ops_util", "//tensorflow/core/kernels:queue_ops", @@ -3235,6 +3242,7 @@ tf_cc_test( "//tensorflow/core/kernels:array", "//tensorflow/core/kernels:data_flow", "//tensorflow/core/kernels:function_ops", + "//tensorflow/core/kernels:function_control_ops", "//tensorflow/core/kernels:math", "//third_party/eigen3", ], diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 8674831eac..5cd1e042c8 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -19,6 +19,9 @@ limitations under the License. #include #include +#include +#include + #include "tensorflow/core/common_runtime/constant_folding.h" #include "tensorflow/core/common_runtime/debugger_state_interface.h" #include "tensorflow/core/common_runtime/device_factory.h" @@ -581,8 +584,13 @@ Status DirectSession::Run(const RunOptions& run_options, return errors::Cancelled("Run call was cancelled"); } + clock_t t; for (const auto& item : executors_and_keys->items) { + t = clock(); + item.executor->RunAsync(args, barrier->Get()); + + } WaitForNotification(&run_state, &step_cancellation_manager, @@ -590,6 +598,10 @@ Status DirectSession::Run(const RunOptions& run_options, ? run_options.timeout_in_ms() : operation_timeout_in_ms_); + t = clock() - t; + std::cout << "time: " << t << " miliseconds" << std::endl; + std::cout << "time: " << t*1.0/CLOCKS_PER_SEC << " seconds" << std::endl; + if (!cancellation_manager_->DeregisterCallback(cancellation_token)) { // The step has been cancelled: make sure we don't attempt to receive the // outputs as this would make it block forever. 
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index 67a4296442..ed3a889ee0 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -216,10 +216,14 @@ struct NodeItem { bool is_merge : 1; // True iff IsMerge(node) bool is_enter : 1; // True iff IsEnter(node) bool is_exit : 1; // True iff IsExit(node) + bool is_call : 1; // True iff IsCall(node) + bool is_return : 1; // True iff IsReturn(node) bool is_control_trigger : 1; // True iff IsControlTrigger(node) bool is_sink : 1; // True iff IsSink(node) // True iff IsEnter(node) || IsExit(node) || IsNextIteration(node) bool is_enter_exit_or_next_iter : 1; + // True iff IsCall(node) || IsReturn(node) + bool is_call_or_return : 1; // Cached values of node->num_inputs() and node->num_outputs(), to // avoid levels of indirection. @@ -233,6 +237,11 @@ struct NodeItem { // Number of output edges. size_t num_output_edges; + string frame_name; // cache the attribute if is_enter | is_exit | is_call | is_return + string dyn_frame_name; // cache the attribute if is_enter | is_exit | is_call | is_return + + int call_id = -1; + PendingCounts::Handle pending_id; const EdgeInfo* output_edge_list() const { return output_edge_base(); } @@ -618,6 +627,8 @@ Status ExecutorImpl::Initialize() { EnsureFrameInfo(it)->nodes = new std::vector<const Node*>; } + std::unordered_map<string, int> input_count; + // Preprocess every node in the graph to create an instance of op // kernel for each node. for (const Node* n : graph_->nodes()) { @@ -649,10 +660,14 @@ item->is_merge = IsMerge(n); item->is_enter = IsEnter(n); item->is_exit = IsExit(n); + item->is_call = IsCall(n); + item->is_return = IsReturn(n); item->is_control_trigger = IsControlTrigger(n); item->is_sink = IsSink(n); item->is_enter_exit_or_next_iter = (IsEnter(n) || IsExit(n) || IsNextIteration(n)); + item->is_call_or_return = + (IsCall(n) || IsReturn(n)); // Compute the maximum values we'll store for this node in the // pending counts data structure, and allocate a handle in @@ -666,9 +681,22 @@ // Initialize static information about the frames in the graph. frame_info->nodes->push_back(n); if (IsEnter(n)) { - string enter_name; - TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &enter_name)); - EnsureFrameInfo(enter_name)->input_count++; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &item->frame_name)); + item->dyn_frame_name = item->frame_name; + } + if (item->is_call_or_return) { + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &item->frame_name)); + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "call_id", &item->call_id)); + item->dyn_frame_name = strings::StrCat(item->call_id); + } + if (item->is_enter) { + EnsureFrameInfo(item->frame_name)->input_count++; + } + if (item->is_call) { + input_count[item->dyn_frame_name]++; + // The following assumes that all calls of the same function have the same number of inputs, + // which holds for any well-formed graph (produced by the transformation). + EnsureFrameInfo(item->frame_name)->input_count = input_count[item->dyn_frame_name]; } } @@ -976,6 +1004,9 @@ class ExecutorState { // frame_name. uint64 frame_id; + + int call_id = -1; + // The iteration id of its parent frame when this frame is created. // -1 if there is no parent frame. The frame_name/parent_iter pair // uniquely identifies this FrameState.
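+ // A call_id identifies one dynamic function call: the Call node, the child frame it creates, and the matching Return nodes all carry the same call_id, which is how Return outputs are routed back to the caller's frame (see ActivateNodes below).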
@@ -1021,6 +1052,13 @@ class ExecutorState { int total_input_tensors = 0; std::vector<const Node*>* nodes = nullptr; + // Mapping from frame id to this frame's outstanding child frames. A new + // frame is created at some iteration of an active frame. So the unique key + // for a new child frame is composed of the id of the parent frame, the + // iteration number at which the parent frame is creating the new frame, and + // the name of the new frame from nodedef (hashed into a single frame id). + gtl::FlatMap<uint64, FrameState*> outstanding_child_frames_ GUARDED_BY(mu); + // Lock ordering: ExecutorState.mu_ < mu. mutex mu; @@ -1203,17 +1241,15 @@ class ExecutorState { mutex mu_; Status status_ GUARDED_BY(mu_); - // Mapping from frame name to outstanding frames. A new frame is created - // at some iteration of an active frame. So the unique key for the new - // child frame is composed of the name of the parent frame, the iteration - // number at which the parent frame is creating the new frame, and the - // name of the new frame from nodedef. - gtl::FlatMap<string, FrameState*> outstanding_frames_ GUARDED_BY(mu_); - // The unique name of a frame. inline string MakeFrameName(FrameState* frame, int64 iter_id, const string& name) { - return strings::StrCat(frame->frame_name, ";", iter_id, ";", name); + //return strings::StrCat(frame->frame_name, frame->frame_id, ";", iter_id, ";", name); + return strings::StrCat(frame->frame_id, ";", iter_id, ";", name); + } + // The unique name of a child frame created by a Call node (no iteration part). + inline string MakeFrameName(FrameState* frame, const string& name) { + return strings::StrCat(frame->frame_id, ";", name); } // Find an existing or create a new child frame in the frame 'frame' at @@ -1316,13 +1352,10 @@ ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl) root_frame_->iterations[0] = new IterationState( root_frame_->pending_counts, root_frame_->total_input_tensors); - outstanding_frames_.insert({root_frame_->frame_name, root_frame_}); } ExecutorState::~ExecutorState() { - for (auto name_frame : outstanding_frames_) { - delete name_frame.second; - } + for (auto it : device_context_map_) { it->Unref(); } @@ -1350,6 +1383,8 @@ Status ExecutorImpl::BuildControlFlowInfo(const Graph* g, } } + std::unordered_map<int, int> call_id_to_call_node_id; + while (!ready.empty()) { Node* curr_node = ready.front(); int curr_id = curr_node->id(); @@ -1366,6 +1401,38 @@ Status ExecutorImpl::BuildControlFlowInfo(const Graph* g, parent = parent_nodes[curr_id]; frame_name = cf_info->frame_names[parent->id()]; parent = parent_nodes[parent->id()]; + } else if (IsCall(curr_node)) { + TF_RETURN_IF_ERROR( + GetNodeAttr(curr_node->attrs(), "frame_name", &frame_name)); + + int call_id; + + TF_RETURN_IF_ERROR( + GetNodeAttr(curr_node->attrs(), "call_id", &call_id)); + // We assume that call_id is unique, so we don't need to concat it with frame_name + // to make it unique.
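+ // Record which Call node carries this call_id; when the matching Return node is visited below, it looks the Call up to inherit its frame and parent.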
+ + call_id_to_call_node_id.emplace(call_id, curr_id); + + parent = curr_node; + + } else if (IsReturn(curr_node)) { + + int call_id; + + TF_RETURN_IF_ERROR( + GetNodeAttr(curr_node->attrs(), "call_id", &call_id)); + + auto it = call_id_to_call_node_id.find(call_id); + + if (it != call_id_to_call_node_id.end()) { + int call_node_id = it->second; + parent = parent_nodes[call_node_id]; + frame_name = cf_info->frame_names[call_node_id]; + } else { + ready.push_back(curr_node); + continue; + } } else { parent = parent_nodes[curr_id]; frame_name = cf_info->frame_names[curr_id]; @@ -1375,6 +1442,8 @@ Status ExecutorImpl::BuildControlFlowInfo(const Graph* g, Node* out = out_edge->dst(); const int out_id = out->id(); + if (IsReturn(out) && out_edge->IsControlEdge()) continue; + // Add to ready queue if not visited. bool is_visited = visited[out_id]; if (!is_visited) { @@ -1917,7 +1986,15 @@ void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node, FrameState* output_frame = input_frame; int64 output_iter = input_iter; - if (!item->is_enter_exit_or_next_iter) { + if (vlog_) { + VLOG(2) << "Propagate Outputs: " << node->name(); + VLOG(2) << "Frame: " << input_frame->frame_name; + } + + printf("Propagate Outputs: %s, am i alive? %d\n", node->name().c_str(), !is_dead); + printf("Frame: %s\n", input_frame->frame_name.c_str()); + + if (!item->is_enter_exit_or_next_iter && !item->is_call_or_return) { // Fast path for nodes types that don't need special handling DCHECK_EQ(input_frame, output_frame); // Normal path for most nodes @@ -1963,6 +2040,34 @@ void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node, is_frame_done = input_frame->DecrementOutstandingOps(&impl_->gview_, input_iter, ready); } + } else if (item->is_call) { +// if (is_dead) { +// // Stop the deadness propagation. +// output_frame = nullptr; +// } else { + FindOrCreateChildFrame(input_frame, input_iter, node, &output_frame); + output_iter = 0; + { + const NodeItem *item = impl_->gview_.node(node->id()); + mutex_lock l(output_frame->mu); + output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready); + output_frame->num_pending_inputs--; + } +// } + is_frame_done = input_frame->DecrementOutstandingOps(&impl_->gview_, input_iter, ready); + } else if (item->is_return) { +// if (is_dead) { +// // Stop the deadness propagation. +// output_frame = nullptr; +// } else { + output_frame = input_frame->parent_frame; + output_iter = input_frame->parent_iter; + { + mutex_lock l(output_frame->mu); + output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready); + } +// } + is_frame_done = input_frame->DecrementOutstandingOps(&impl_->gview_, input_iter, ready); } else { DCHECK(IsNextIteration(node)); mutex_lock l(input_frame->mu); @@ -1985,7 +2090,7 @@ void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node, } } if (output_frame != nullptr) { - // This is the case when node is not Enter, Exit, or NextIteration. + // This is the case when node is not Enter, Exit, NextIteration, Call or Return. 
DCHECK(input_frame == output_frame); output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready); } @@ -2217,15 +2322,18 @@ void ExecutorState::DumpState() { mutex_lock l(mu_); if (!dumped_on_error_) { LOG(WARNING) << "Dumping state"; - for (auto& frame : outstanding_frames_) { - LOG(WARNING) << frame.first; - FrameState* frame_state = frame.second; - mutex_lock frame_lock(frame_state->mu); - for (IterationState* iteration : frame_state->iterations) { - LOG(WARNING) << " Iteration:"; - DumpIterationState(frame_state, iteration); - } - } + + // TODO: Make it print all this info recursively! + +// for (auto& frame : outstanding_frames_) { +// LOG(WARNING) << frame.first; +// FrameState* frame_state = frame.second; +// mutex_lock frame_lock(frame_state->mu); +// for (IterationState* iteration : frame_state->iterations) { +// LOG(WARNING) << " Iteration:"; +// DumpIterationState(frame_state, iteration); +// } +// } dumped_on_error_ = true; } } @@ -2251,16 +2359,20 @@ void ExecutorState::Finish() { void ExecutorState::FindOrCreateChildFrame(FrameState* frame, int64 iter, const Node* node, FrameState** child) { - // Get the child frame name. - string enter_name; - Status s = GetNodeAttr(node->attrs(), "frame_name", &enter_name); - DCHECK(s.ok()) << s; - const string child_name = MakeFrameName(frame, iter, enter_name); + const GraphView& gview = impl_->gview_; + const NodeItem* item = gview.node(node->id()); + Status s; + const string& enter_name = item->frame_name; + const string& dyn_frame_name = item->dyn_frame_name; + const string child_name = item->is_call ? + MakeFrameName(frame, dyn_frame_name) : + MakeFrameName(frame, iter, dyn_frame_name); + const uint64 child_id = Hash64(child_name); { - mutex_lock executor_lock(mu_); - auto it = outstanding_frames_.find(child_name); - if (it != outstanding_frames_.end()) { + mutex_lock frame_lock(frame->mu); + auto it = frame->outstanding_child_frames_.find(child_id); + if (it != frame->outstanding_child_frames_.end()) { *child = it->second; return; } } @@ -2271,13 +2383,20 @@ void ExecutorState::FindOrCreateChildFrame(FrameState* frame, int64 iter, if (vlog_) VLOG(2) << "Create frame: " << child_name; int parallel_iters; + if (IsCall(node)) { + // Since this is not a loop frame, there are no parallel iterations. + parallel_iters = 1; + } else { + s = GetNodeAttr(node->attrs(), "parallel_iterations", &parallel_iters); + DCHECK(s.ok()) << s; + } + FrameState* temp = new FrameState(impl_, parallel_iters); temp->frame_name = child_name; - temp->frame_id = Hash64(child_name); + temp->frame_id = child_id; temp->parent_frame = frame; temp->parent_iter = iter; + temp->call_id = item->call_id; temp->InitializeFrameInfo(enter_name); // 'iterations' is a fixed-length circular buffer.
@@ -2287,14 +2406,13 @@ void ExecutorState::FindOrCreateChildFrame(FrameState* frame, int64 iter, new IterationState(temp->pending_counts, temp->total_input_tensors); { - mutex_lock executor_lock(mu_); - auto it = outstanding_frames_.find(child_name); - if (it != outstanding_frames_.end()) { + mutex_lock frame_lock(frame->mu); + auto it = frame->outstanding_child_frames_.find(child_id); + if (it != frame->outstanding_child_frames_.end()) { *child = it->second; } else { - mutex_lock frame_lock(frame->mu); frame->GetIteration(iter)->outstanding_frame_count++; - outstanding_frames_[child_name] = temp; + frame->outstanding_child_frames_[child_id] = temp; *child = temp; temp = nullptr; } @@ -2307,7 +2425,7 @@ void ExecutorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) { FrameState* parent_frame = frame->parent_frame; const int64 parent_iter = frame->parent_iter; if (parent_frame != nullptr) { - mutex_lock paranet_frame_lock(parent_frame->mu); + mutex_lock parent_frame_lock(parent_frame->mu); // Propagate all the dead exits to the parent frame. for (const Node* node : frame->dead_exits) { auto parent_iter_state = parent_frame->GetIteration(parent_iter); @@ -2357,8 +2475,10 @@ void ExecutorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) { const string& frame_name = frame->frame_name; if (vlog_) VLOG(2) << "Delete frame " << frame_name; { - mutex_lock executor_lock(mu_); - outstanding_frames_.erase(frame_name); + if (parent_frame != nullptr) { + mutex_lock parent_frame_lock(parent_frame->mu); + parent_frame->outstanding_child_frames_.erase(frame->frame_id); + } } delete frame; } @@ -2448,13 +2568,36 @@ void ExecutorState::FrameState::ActivateNodes(const NodeItem* item, } } } else { + // In the case of a "Return" dst_node, we compare the node's + // call_id attr with the call_id of the current frame; + // if they differ, this Return belongs to a different call, so skip it. + if (dst_item->is_return) { + if (dst_item->call_id != call_id) + continue; + } + const bool increment_dead = (is_dead || (!is_control_edge && !(*outputs)[src_slot].has_value)); int pending, dead; iter_state->adjust_for_activation(dst_pending_id, increment_dead, &pending, &dead); - dst_dead = (dead > 0); - dst_ready = (pending == 0); + + + if (dst_item->is_return && increment_dead) { + // The only dead input a Return op may ever get + // is the control input propagated to it from a corresponding + // dead Call op in the case of an untaken branch. So at this point + // we are certain that the Return op will never receive another input. + // Therefore, we force it into the ready queue for the sake of + // deadness propagation and we adjust it for activation once more, + // so that it no longer waits for another (never coming) input.
+ iter_state->adjust_for_activation(dst_pending_id, increment_dead, + &pending, &dead); + } + + dst_dead = (dead > 0); + dst_ready = (pending == 0); + } if (dst_need_input) { @@ -2469,6 +2612,7 @@ void ExecutorState::FrameState::ActivateNodes(const NodeItem* item, // Add dst to the ready queue if it's ready if (dst_ready) { + printf(" Add in queue: %s\n", dst_item->node->name().c_str()); if (dst_item->is_control_trigger) dst_dead = false; ready->push_back(TaggedNode(dst_item->node, this, iter, dst_dead)); iter_state->outstanding_ops++; diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc index 4bd40c7978..7d09fe4f59 100644 --- a/tensorflow/core/common_runtime/graph_execution_state.cc +++ b/tensorflow/core/common_runtime/graph_execution_state.cc @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include "tensorflow/core/util/event.pb.h" +#include "tensorflow/core/util/events_writer.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/common_runtime/placer.h" @@ -361,6 +363,27 @@ Status GraphExecutionState::OptimizeGraph( optimized_graph->reset(new Graph(OpRegistry::Global())); TF_RETURN_IF_ERROR( ConvertGraphDefToGraph(opts, new_graph, optimized_graph->get())); +/*******************************************************************************************/ + // Write an event so that we can visualize this optimized graph in TensorBoard. + EventsWriter writer("Fully_Optimized"); + Event event; + event.set_wall_time(1234); + event.set_step(34); + + const size_t proto_size = new_graph.ByteSizeLong(); + void* buf = port::Malloc(proto_size); + if (buf == nullptr) { + return tensorflow::errors::ResourceExhausted("Failed to allocate memory to serialize message of type '", new_graph.GetTypeName(), "' and size ", proto_size); + } + new_graph.SerializeToArray(buf, proto_size); + const void* bf = buf; + event.set_graph_def(bf, proto_size); + port::Free(buf); // the event owns a copy of the serialized graph + writer.WriteEvent(event); +/*******************************************************************************************/ + + VLOG(1) << "Transformation passed successfully"; + // The graph conversion sets the requested device names but not the assigned // device names. However, since at this point the graph is placed TF expects // an assigned device name for every node. Therefore we copy the requested diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index 995422644a..d75ec3f644 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -19,6 +19,10 @@ limitations under the License.
#include #include +#include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/util/event.pb.h" +#include "tensorflow/core/util/events_writer.h" + #include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/common_runtime/profile_handler.h" #include "tensorflow/core/common_runtime/stats_publisher_interface.h" @@ -280,6 +284,39 @@ Status MasterSession::ReffedClientGraph::RegisterPartitions( std::unordered_map graph_defs; Status s = DoBuildPartitions(popts, &graph_defs); if (s.ok()) { + + +printf("\n\n MASTER PARTITIONS:\n"); +int i=0; +for (const auto& it: graph_defs) { + string dvc = it.first; + const GraphDef* graphDef = &it.second; + printf("\n\nDeviceName: '%s'\n", dvc.c_str()); + printf("Partition GraphDef:\n %s\n", SummarizeGraphDef(*graphDef).c_str()); + + string p = strings::StrCat("Partition", i); i++; + EventsWriter writer(p); + Event event; + event.set_wall_time(1234); + event.set_step(34); + + const size_t proto_size = graphDef->ByteSizeLong(); + void* buf = port::Malloc(proto_size); + if (buf == nullptr) { + return errors::ResourceExhausted( + "Failed to allocate memory to serialize message of type '", + graphDef->GetTypeName(), "' and size ", proto_size); + } + graphDef->SerializeToArray(buf, proto_size); + const void* bf = buf; + event.set_graph_def(bf, proto_size); + port::Free(buf); // the event owns a copy of the serialized partition + writer.WriteEvent(event); + +} + + + + // NOTE(mrry): The pointers in `graph_defs_for_publishing` do not remain // valid after the call to DoRegisterPartitions begins, so // `stats_publisher_` must make a copy if it wants to retain the @@ -1543,9 +1580,20 @@ Status MasterSession::DoRunWithLocalExecution( pss.collect_rpcs = ph->should_collect_rpcs(); } +// For future "execution-time" testing - when run on truly separate machines +// clock_t t; +// t = clock(); + Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, resp, &cancellation_manager_, false); if (s.ok()) { + +// +// t = clock() - t; +// std::cout << "time: " << t << " clock ticks" << std::endl; +// std::cout << "time: " << t*1.0/CLOCKS_PER_SEC << " seconds" << std::endl; + + pss.end_micros = Env::Default()->NowMicros(); // Schedule post-processing and cleanup to be done asynchronously.
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index 2ad0081e1f..64db8e2207 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -61,6 +61,8 @@ const std::unordered_map& Node::kNodeClassTable = REF_CLASS("Enter", NC_ENTER), REF_CLASS("Exit", NC_EXIT), REF_CLASS("NextIteration", NC_NEXT_ITERATION), + REF_CLASS("Call", NC_CALL), + REF_CLASS("Return", NC_RETURN), {"LoopCond", NC_LOOP_COND}, {"ControlTrigger", NC_CONTROL_TRIGGER}, {"_Send", NC_SEND}, diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index 5a31a6216b..a1f8e6d46d 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -142,6 +142,8 @@ class Node { bool IsEnter() const { return class_ == NC_ENTER; } bool IsExit() const { return class_ == NC_EXIT; } bool IsNextIteration() const { return class_ == NC_NEXT_ITERATION; } + bool IsCall() const { return class_ == NC_CALL; } + bool IsReturn() const { return class_ == NC_RETURN; } bool IsLoopCond() const { return class_ == NC_LOOP_COND; } bool IsControlTrigger() const { return class_ == NC_CONTROL_TRIGGER; } bool IsSend() const { return class_ == NC_SEND || class_ == NC_HOST_SEND; } @@ -157,7 +159,7 @@ class Node { bool IsControlFlow() const { return (class_ != NC_OTHER) && // Fast path (IsSwitch() || IsMerge() || IsEnter() || IsExit() || - IsNextIteration()); + IsNextIteration() || IsCall() || IsReturn()); } bool IsHostSend() const { return class_ == NC_HOST_SEND; } bool IsHostRecv() const { return class_ == NC_HOST_RECV; } @@ -219,6 +221,8 @@ class Node { NC_ENTER, NC_EXIT, NC_NEXT_ITERATION, + NC_CALL, + NC_RETURN, NC_LOOP_COND, NC_CONTROL_TRIGGER, NC_SEND, @@ -655,6 +659,8 @@ inline bool IsMerge(const Node* node) { return node->IsMerge(); } inline bool IsEnter(const Node* node) { return node->IsEnter(); } inline bool IsExit(const Node* node) { return node->IsExit(); } inline bool IsNextIteration(const Node* n) { return n->IsNextIteration(); } +inline bool IsCall(const Node* node) { return node->IsCall(); } +inline bool IsReturn(const Node* node) { return node->IsReturn(); } inline bool IsLoopCond(const Node* node) { return node->IsLoopCond(); } inline bool IsControlTrigger(const Node* n) { return n->IsControlTrigger(); } inline bool IsSend(const Node* node) { return node->IsSend(); } diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index 15f7b9fe8c..cb0f230311 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -52,6 +52,16 @@ inline bool IsNextIteration(const NodeDef& node_def) { node_def.op() == "RefNextIteration"; } +inline bool IsCall(const NodeDef& node_def) { + return node_def.op() == "Call" || + node_def.op() == "RefCall"; +} + +inline bool IsReturn(const NodeDef& node_def) { + return node_def.op() == "Return" || + node_def.op() == "RefReturn"; +} + bool IsValidNodeName(StringPiece s, bool allow_internal_ops) { using ::tensorflow::strings::Scanner; return Scanner(s) @@ -137,7 +147,10 @@ class GraphConstructor { original_versions_(g->versions()), refiner_(refiner), return_tensors_(return_tensors), - unused_input_map_keys_(unused_input_map_keys) {} + unused_input_map_keys_(unused_input_map_keys) { + + SetFunctionReturningNodes(node_defs); + } Status TryImport() { TF_RETURN_IF_ERROR(EnsureNoNameCollisions()); @@ -183,7 +196,52 @@ class GraphConstructor { void AddPrefixToNodeDef(const std::vector& input_already_exists, NodeDef* node_def); - // From constructor + 
bool IsReturningNode(const NodeDef& node_def) { + return (function_returning_nodes_.find(node_def.name()) != + function_returning_nodes_.end()); + } + + void SetFunctionReturningNodes(const NodeDefSlice& node_defs) { + + std::unordered_map<string, std::set<int>> returning_nodes; + + for (int n = 0; n < node_defs.size(); ++n) { + const NodeDef& node_def = *node_defs[n]; + if (IsReturn(node_def)) { + // Nodes that send their output to "Return" nodes are + // function returning nodes, and in the case of recursive functions + // those nodes are part of graph cycles. + for (const auto& input : node_def.input()) { + // In order to detect the recursion cycles we depend on + // the fact that a recursive function's returning node + // will be sending outputs to at least 2 "Return" nodes + // with different "call_id" attributes (same "call_id" + // attrs would mean that they belong to the same function call + // but correspond to different function outputs) + if (!StringPiece(input).starts_with("^")) { + int call_id; + GetNodeAttr(AttrSlice(node_def), "call_id", &call_id); + + size_t pos; + string prevNode; + prevNode = ((pos = input.find(":")) != std::string::npos) ? + input.substr(0, pos) : input; + + returning_nodes[prevNode].emplace(call_id); + } + } + } + } + for (auto& retnode : returning_nodes) { + if (retnode.second.size() > 1) { + // Detected Cycle + function_returning_nodes_.insert(retnode.first); + } + } + } + + + // From constructor const Options opts_; const NodeDefSlice node_defs_; const VersionDef* versions_; @@ -251,6 +309,8 @@ class GraphConstructor { int dst_index; }; std::vector<EdgeInfo> back_edges_; + + std::unordered_set<string> function_returning_nodes_; }; // This could be expensive but we don't expect to call it often, if at all (only @@ -398,21 +458,49 @@ std::unordered_set<string> GetNextIterationNodes( return next_iteration_nodes; } +std::unordered_set<string> GetCallNodes( + const GraphConstructor::NodeDefSlice& node_defs) { + std::unordered_set<string> call_nodes; + + for (int n = 0; n < node_defs.size(); ++n) { + const NodeDef& node_def = *node_defs[n]; + if (IsCall(node_def)) { + call_nodes.insert(node_def.name()); + } + } + + return call_nodes; +} + Status GraphConstructor::InitFromEdges() { const int num_nodes = node_defs_.size(); pending_count_.reserve(num_nodes); outputs_.resize(num_nodes); std::unordered_set<string> next_iteration_nodes_ = GetNextIterationNodes(node_defs_); + std::unordered_set<string> call_nodes_ = + GetCallNodes(node_defs_); // Parse the inputs for each node. for (int n = 0; n < num_nodes; ++n) { const NodeDef& node_def = *node_defs_[n]; + + if (IsReturningNode(node_def)) { + int32 num_control_edges = 0; + for (int i = 0; i < node_def.input_size(); ++i) { + if (StringPiece(node_def.input(i)).starts_with("^")) { + num_control_edges++; + } + } + pending_count_.push_back(num_control_edges + 1); + + } else if (IsMerge(node_def)) { + // Cycles in the graph are only allowed for while loops and recursion. + // A while loop is identified by an edge from a NextIteration node to a Merge node.
+ // Recursion is identified by an edge from a Call node to a Merge node. + // In recursion, function returning nodes also participate in a cycle. + // For such Merge nodes and for function returning nodes, wait for only + // one non-control input before considering the node ready to process in Convert(). int32 num_control_edges = 0; bool has_loop_back_edge = false; for (int i = 0; i < node_def.input_size(); ++i) { @@ -422,7 +510,9 @@ } else { TensorId id(ParseTensorName(input_name)); if (next_iteration_nodes_.find(id.first.ToString()) != - next_iteration_nodes_.end()) { + next_iteration_nodes_.end() || + call_nodes_.find(id.first.ToString()) != + call_nodes_.end()) { has_loop_back_edge = true; } } @@ -807,10 +897,10 @@ Status GraphConstructor::Convert() { inputs.push_back(InputInfo(id.first.ToString(), src_node, src_index)); } - if (has_data_back_edge && !IsMerge(*node_def)) { + if (has_data_back_edge && !IsMerge(*node_def) && !IsReturningNode(*node_def)) { return errors::InvalidArgument( "Node '", node_def->name(), - "' had a back edge, but only Merge nodes can have back edges."); + "' had a back edge, but only Merge and returning nodes can have back edges."); } Node* node; diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc index 71d8cdd6ab..2744be9077 100644 --- a/tensorflow/core/graph/graph_partition.cc +++ b/tensorflow/core/graph/graph_partition.cc @@ -22,6 +22,10 @@ limitations under the License. #include #include +#include "tensorflow/core/util/event.pb.h" +#include "tensorflow/core/util/events_writer.h" +#include "tensorflow/core/graph/graph_constructor.h" + #include "tensorflow/core/framework/memory_types.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor.pb.h" @@ -38,6 +42,9 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/device_name_utils.h" + +#include "tensorflow/core/framework/graph_def_util.h" + namespace tensorflow { namespace { @@ -931,20 +938,681 @@ void SetIncarnation(const PartitionOptions& opts, GraphDef* gdef) { } } + +/**************************************************************************************************/ + +struct StateMachineNodeInput { + string src; + int index; +}; + +struct StateMachineParent { + Node* parent_node; + int parent_index; +}; + +struct StateMachineNode { + Node* node; + std::vector inputs; +}; + +struct StateMachineGraph { + std::unordered_map nodes; + std::set depends_on; + Node* merge; +}; + +struct StateMachine { + // A map from unique_ids to StateMachineGraphs representing a general dynamic + // state machine that we update every time a function gets called, and helps us + // gradually build the state machines of the partitions + std::unordered_map state_machine_graphs; + // state_machine_parents is the 'spine' of the graph, + // containing only control flow nodes + std::vector state_machine_parents; + + std::unordered_map switches_info; + // + std::unordered_map switchToPred; + + string leader_partition; + + // Maps device names to smaller strings + std::unordered_map device_names_map; + + std::unordered_map*> partitionsToSMG; +}; + +struct FuncInfo { + // A map from to the num of function's arguments + std::unordered_map funcInputs; + // Helps us seperate functions with same frame_name but + // different non recursive call sites + std::unordered_map funcVisitedCounter; + // Εach vector below operates as a barrier, + // we don't call CallingFunction(..) before we gather + // all function's arguments/calls first + std::unordered_map*> funcCalls; +}; + +// Adds root nodes into ready_nodes queue and sets ready_inputs appropriately +void PreprocessGraph(std::unordered_map &ready_inputs, Graph* g, + std::deque &ready_nodes) { + + std::unordered_map> returning_nodes; + + for (Node* node : g->nodes()) { + + if (node->in_edges().empty()) { + ready_nodes.push_back(node); + } + bool recursion_merge = 0; + if (IsMerge(node)) { + ready_inputs[node] = 0; + for (const Edge* in_edge : node->in_edges()) { + + Node* in = in_edge->src(); + // if (IsNextIteration(*output_map.GetNode(input))) { + // ready_inputs[node]++; + // } + if (IsCall(in)) { + ready_inputs[node]++; + recursion_merge = 1; + } + } + if (recursion_merge) { + ready_inputs[node]--; + recursion_merge = 0; + } + + } else if (IsReturn(node)) { + + for (const Edge* in_edge : node->in_edges()) { + Node* in = in_edge->src(); + + if (!in_edge->IsControlEdge()) { + int call_id; + GetNodeAttr(node->attrs(), "call_id", &call_id); + returning_nodes[in].emplace(call_id); + } + } + ready_inputs[node] = 0; + + } else { + ready_inputs[node] = 0; + } + } + + for (const auto& retnode : returning_nodes) { + if (retnode.second.size() > 1) { + // Detected Cycle + ready_inputs[retnode.first]++; + } + } +} + +string GetDeviceMappedName(StateMachine &state_machine, string device_name) { + + std::unordered_map& device_map = state_machine.device_names_map; + + auto slot = &device_map[device_name]; + if (*slot == "") + *slot = strings::StrCat("_p", device_map.size() + 1); + return *slot; +} + +bool IsCallSuccessor(Node* node) { + + for (const Edge* in_edge : node->in_edges()) { + Node* src = in_edge->src(); + if (IsCall(src) && !in_edge->IsControlEdge()) + return true; + } + return false; +} + +void DeleteStateMachineGraph(StateMachine& state_machine, string 
unique_id) { + + StateMachineGraph *smg = state_machine.state_machine_graphs[unique_id]; + + for (auto& it : smg->nodes) + delete it.second; + delete smg; +} + +std::vector<Node*>* GetOrCreateCalls(int call_id, std::unordered_map<int, std::vector<Node*>*> &funcCalls) { + auto slot = &funcCalls[call_id]; + if (*slot == nullptr) + *slot = new std::vector<Node*>; + return *slot; +} + +std::set<string>* GetOrCreatePartition(string partition, std::unordered_map<string, std::set<string>*> &partsToSmg) { + auto slot = &partsToSmg[partition]; + if (*slot == nullptr) + *slot = new std::set<string>; + return *slot; +} + +// For a single if-else construct there is more than one Switch node, guarding all the inputs +// that are needed inside the branches but live outside of them. We need to collect all the Switch +// nodes that correspond to one if-else construct and treat them as one in the state machines. +// switches_info: every switch node maps to the original switch that we'll take into account +void CollectSwitches(Graph* g, StateMachine& state_machine) { + + std::unordered_map<Node*, Node*> pred_switch; + + for (Node *node : g->nodes()) { + + if (IsSwitch(node)) { + + for (const Edge *in_edge : node->in_edges()) { + + int port = in_edge->dst_input(); + + // A sloppy way to determine if this is the predicate input + if (!in_edge->IsControlEdge() && port == 1) { + + Node *predicate = in_edge->src(); + + while (IsIdentity(predicate)) { + for (const Edge *inEdge : predicate->in_edges()) { + if (!inEdge->IsControlEdge()) { + predicate = inEdge->src(); + break; + } + } + } + + // We've got the real predicate + Node *switchNode; + if (pred_switch.find(predicate) == pred_switch.end()) { + // Original switch + pred_switch[predicate] = node; + state_machine.switchToPred[node] = predicate; + switchNode = node; + } else { + // "Synonym" switch + switchNode = pred_switch[predicate]; + } + + state_machine.switches_info[node] = switchNode; + + break; + } + } + printf("Switch : %s -> %s\n", node->name().c_str(), state_machine.switches_info[node]->name().c_str()); + } + } + + printf("\n\n\n"); +} + +void GatherPartitionStateMachines(StateMachine& state_machine, std::set<string>* smgs) { + + std::deque<string> queue; + + for (auto& it : *smgs) + queue.push_back(it); + + while (!queue.empty()) { + string smg = queue.front(); + queue.pop_front(); + + StateMachineGraph* sm_graph = state_machine.state_machine_graphs[smg]; + for (auto& it : sm_graph->depends_on) { + // If not already visited + if (smgs->find(it) == smgs->end()) { + smgs->emplace(it); + queue.push_back(it); + } + } + } +} + +NodeDef* FindNodeInGraphDef(GraphDef& graphDef, string node_name) { + + for (NodeDef& nodeDef : *graphDef.mutable_node()) { + if (nodeDef.name() == node_name) + return &nodeDef; + } + return nullptr; +} + +void ConnectMergeToNode(GraphDef& graphDef, string merge_name, string node_name, + StateMachine& state_machine, string partition_name) { + + // We can safely infer the correct Merge's name and add it as a control input to the node, + // even though the partition state machine's Merge has not been added to the graphdef yet + string suffix = (partition_name != state_machine.leader_partition) ?
+ (suffix = GetDeviceMappedName(state_machine, partition_name)) : (suffix = ""); + + //Add as control input + NodeDef* node = FindNodeInGraphDef(graphDef, node_name); + *node->add_input() = strings::StrCat("^", merge_name, suffix); +} + +void AddPartitionStateMachine(StateMachine& state_machine, GraphDef& main_graphDef, + string unique_id, string partition) { + + StateMachineGraph *sm_graph = state_machine.state_machine_graphs[unique_id]; + string suffix = GetDeviceMappedName(state_machine, partition); + for (const auto &it : sm_graph->nodes) { + string node_name = it.first; + StateMachineNode *sm_node = it.second; + Node *node = sm_node->node; + + // Build NodeDef + NodeDef *nodedef = main_graphDef.add_node(); + //Note: suffix does not guarantee that name is unique + nodedef->set_name(strings::StrCat(node_name, suffix)); + nodedef->set_op(node->op_def().name()); + nodedef->set_device(partition); + + // Add Inputs + for (int i = 0; i < sm_node->inputs.size(); ++i) { + // There won't exist any control inputs here + nodedef->add_input(strings::StrCat(sm_node->inputs[i].src, suffix, ":", sm_node->inputs[i].index)); + + if (StringPiece(sm_node->inputs[i].src).starts_with("Dummy_")) { + Tensor tensor(DT_INT32, TensorShape({0})); + NodeDef* dummy = main_graphDef.add_node(); + dummy->set_name(strings::StrCat(sm_node->inputs[i].src, suffix)); + dummy->set_op("Const"); + dummy->set_device(partition); + AddNodeAttr("dtype", DT_INT32, dummy); + AddNodeAttr("value", tensor, dummy); + } + } + + if (IsSwitch(node)) { + // Add predicate input too + nodedef->add_input(state_machine.switchToPred[node]->name()); + // Add control input from partition's Merge to partition's Switch + nodedef->add_input(strings::StrCat("^", sm_graph->merge->name(), suffix)); + } + + for (const auto &itt : node->def().attr()) { + // Not sure if this is copying attrs correctly + if (itt.first == "T") { + // We don't care about keeping the original "T" attr + // in state machine nodes + AddNodeAttr(itt.first, DT_INT32, nodedef); + } else + AddNodeAttr(itt.first, itt.second, nodedef); + } + } +} + +void AddNodeToStateMachine(StateMachine& state_machine, string unique_id, Node* node, bool cycle) { + + StateMachineGraph *smg = state_machine.state_machine_graphs[unique_id]; + StateMachineNode *smn = new StateMachineNode; + + smn->node = node; + + StateMachineParent *parent = &state_machine.state_machine_parents[node->id()]; + + if (parent->parent_node == nullptr) { + int call_id; + GetNodeAttr(node->attrs(), "call_id", &call_id); + smn->inputs.push_back({strings::StrCat("Dummy_", call_id), 0}); + } else + smn->inputs.push_back({parent->parent_node->name(), parent->parent_index}); + + smg->nodes[node->name()] = smn; + + // If cycle is true, node is a recursive call, that needs to be added as + // input to the corresponding Merge node + if (cycle) { + // We traverse graph the way topological sort does, so we will never + // meet a recursive call node before its corresponding Merge + StateMachineNode* merge = smg->nodes[smg->merge->name()]; + merge->inputs.push_back({node->name(), 0}); + } +} + +void CallingFunction(Graph* graph, GraphDef& main_graphDef, StateMachine& state_machine, FuncInfo& funcInfo, + string function_frame_name, int function_call_id, + std::unordered_map& ready_inputs, + std::deque& prev_ready_nodes) { + + Node *merge, *call; + std::deque ready_nodes; + + string function_unique_id = strings::StrCat(function_frame_name, ":", + funcInfo.funcVisitedCounter[function_frame_name]); + + std::vector* calls = 
funcInfo.funcCalls[function_call_id]; + for (int i = 0; i < calls->size(); ++i) { + ready_nodes.push_back((*calls)[i]); + } + call = (*calls)[0]; + + // We add only one Call node to the state machine, no matter how many args the function has + AddNodeToStateMachine(state_machine, function_unique_id, call, false); + + std::vector<StateMachineParent>& state_machine_parents = state_machine.state_machine_parents; + StateMachineGraph* sm_graph = state_machine.state_machine_graphs[function_unique_id]; + + // Call's successor (the non-control output) will be either + // a Merge node (in case of recursion) or an Identity node. + // Either way we add that successor to the state machine, too. + // Same as above, we add only one Merge node instead of one per function arg + for (const Edge* out_edge : call->out_edges()) { + if (!out_edge->IsControlEdge()) { + merge = out_edge->dst(); + state_machine_parents[merge->id()].parent_node = call; + state_machine_parents[merge->id()].parent_index = 0; + AddNodeToStateMachine(state_machine, function_unique_id, merge, false); + sm_graph->merge = merge; + break; + } + } + + while (!ready_nodes.empty()) { + + Node* ready_node = ready_nodes.front(); + ready_nodes.pop_front(); + + int parent_index = 0; + Node* parent = state_machine_parents[ready_node->id()].parent_node; + + // The ops below need to update the parent + if (IsCall(ready_node)) { + parent = call; + } else if (IsCallSuccessor(ready_node)) { + parent = merge; + } else if (IsSwitch(ready_node)) { + Node *sw = state_machine.switches_info[ready_node]; + if (sw == ready_node) + AddNodeToStateMachine(state_machine, function_unique_id, ready_node, false); + parent = sw; + } else if (IsMerge(ready_node)) { + // A control-flow (regular) Merge has a corresponding Switch node; + // the parent gets the value of that Switch node's parent + parent = state_machine_parents[parent->id()].parent_node; + parent_index = state_machine_parents[parent->id()].parent_index; + } else if (IsReturn(ready_node)) { + // Return needs to propagate its corresponding Call's parent to all its successors + for (const Edge* in_edge : ready_node->in_edges()) { + if (in_edge->IsControlEdge()) { + Node* call_node = in_edge->src(); + parent = state_machine_parents[call_node->id()].parent_node; + parent_index = state_machine_parents[call_node->id()].parent_index; + break; + } + } + int call_id; + GetNodeAttr(ready_node->attrs(), "call_id", &call_id); + // If not a 'recursive' return + if (call_id == function_call_id) { + // Add the successors of the Return node to the prev_ready_nodes queue + prev_ready_nodes.push_back(ready_node); + // Set the parent value of the only actual output of the Return + for (const Edge* out_edge : ready_node->out_edges()) { + Node* out = out_edge->dst(); + state_machine_parents[out->id()].parent_node = parent; + state_machine_parents[out->id()].parent_index = parent_index; + break; + } + continue; + } + } + + // Process ready_node's outputs + for (const Edge* out_edge : ready_node->out_edges()) { + Node* out = out_edge->dst(); + + ready_inputs[out]++; + + // For a cross-device edge, on the dst device, add a control edge + // from the merge node of the state machine to dst. If a send/recv is + // introduced for this edge in future partitioning, we delete this + // control edge and add a new control edge from the merge to the recv.
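+      // For example (node names here are illustrative only): if the dst
+      // partition's state machine merge is "Fib/Merge_p2", the control edge
+      // ^Fib/Merge_p2 -> dst added now becomes ^Fib/Merge_p2 -> recv once a
+      // Send/Recv pair is later introduced for this data edge.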
+ const string& src_device = ready_node->assigned_device_name(); + const string& dst_device = out->assigned_device_name(); + if (src_device != dst_device) { + if (IsCallSuccessor(ready_node) && IsConstant(out)) { + // Remove this control edge that ensures constant executes in the same frame, + // and add a new one from the Constant's partition's state machine merge to the constant + NodeDef* con_node = FindNodeInGraphDef(main_graphDef, out->name()); + for (string& input : *con_node->mutable_input()) { + if (StringPiece(input).starts_with(strings::StrCat("^", ready_node->name()))) { + string suffix = GetDeviceMappedName(state_machine, dst_device); + input = strings::StrCat("^", merge->name(), suffix); + break; + } + } + } else + ConnectMergeToNode(main_graphDef, merge->name(), out->name(), state_machine, dst_device); + } + + if (ready_inputs[out] == out->in_edges().size()) { + + if (IsSwitch(ready_node)) { + // We need to fix parent_index appropriately + parent_index = out_edge->src_output(); + } + + // Set node's parent + state_machine_parents[out->id()].parent_node = parent; + state_machine_parents[out->id()].parent_index = parent_index; + + std::unordered_map& sm_graphs = state_machine.state_machine_graphs; + + if (IsCall(out)) { + + string frame_name; + GetNodeAttr(out->attrs(), "frame_name", &frame_name); + int call_id; + GetNodeAttr(out->attrs(), "call_id", &call_id); + + std::vector* calls = GetOrCreateCalls(call_id, funcInfo.funcCalls); + calls->push_back(out); + + if (funcInfo.funcInputs[frame_name] == calls->size()) { + + // We gathered all function's inputs + + string unique_id = strings::StrCat(frame_name, ":", funcInfo.funcVisitedCounter[frame_name]); + + if (sm_graphs.find(unique_id) == sm_graphs.end()) { + + sm_graphs.emplace(unique_id, new StateMachineGraph); + CallingFunction(graph, main_graphDef, state_machine, funcInfo, frame_name, call_id, ready_inputs, ready_nodes); + funcInfo.funcVisitedCounter[frame_name]++; + } else { + // Recursive Call (either to the same function or another one (mutual recursion) + AddNodeToStateMachine(state_machine, unique_id, (*calls)[0], true); + // Add the recursive call nodes to ready_nodes + for (int i=0; i < calls->size(); ++i) + ready_nodes.push_back((*calls)[i]); + } + + sm_graphs[unique_id]->depends_on.emplace(function_unique_id); + } + } else { + GetOrCreatePartition(dst_device, state_machine.partitionsToSMG)->emplace(function_unique_id); + ready_nodes.push_back(out); + } + } + } + } +} + +Status AddFunctionStateMachines(const PartitionOptions& opts, + Graph* g, GraphDef& main_graphDef, GraphInfo* g_info) { + + Status status; + GraphDefBuilder::Options bopts(g, &status); + + FuncInfo funcInfo; + int nodes_num = g->num_node_ids(); + + const FunctionDefLibrary& fdef = opts.flib_def->ToProto(); + for (const FunctionDef& func : fdef.function()) { + + int num_inputs = func.signature().input_arg_size(); + string name = func.signature().name(); + funcInfo.funcInputs[name] = num_inputs; + funcInfo.funcVisitedCounter[name] = 0; + } + + StateMachine state_machine; + state_machine.state_machine_parents.resize(nodes_num); + + CollectSwitches(g, state_machine); + + // Add all state machines for cross-device frames. + // A state machine is added only when there is a cross-device edge in a + // non-root frame. 
+ + // Visit nodes the way topological sort does + std::deque<Node*> ready_nodes; + std::unordered_map<Node*, int> ready_inputs; + + PreprocessGraph(ready_inputs, g, ready_nodes); + + // We convert the graph to its equivalent graph_def, because it's easier + // to extend it with the GraphDef state machines of the partitions + g->ToGraphDef(&main_graphDef); + + while (!ready_nodes.empty()) { + Node* ready_node = ready_nodes.front(); + ready_nodes.pop_front(); + + for (const Edge* out_edge : ready_node->out_edges()) { + Node* out = out_edge->dst(); + + ready_inputs[out]++; + + if (ready_inputs[out] == out->in_edges().size()) { + + if (IsCall(out)) { + string frame_name; + GetNodeAttr(out->attrs(), "frame_name", &frame_name); + int call_id; + GetNodeAttr(out->attrs(), "call_id", &call_id); + + std::vector<Node*>* calls = GetOrCreateCalls(call_id, funcInfo.funcCalls); + calls->push_back(out); + + if (funcInfo.funcInputs[frame_name] == calls->size()) { + + string unique_id = strings::StrCat(frame_name, ":", funcInfo.funcVisitedCounter[frame_name]); + + // We have gathered all of the function's inputs + state_machine.leader_partition = out->assigned_device_name(); + state_machine.state_machine_graphs.emplace(unique_id, new StateMachineGraph); + CallingFunction(g, main_graphDef, state_machine, funcInfo, frame_name, call_id, ready_inputs, ready_nodes); + funcInfo.funcVisitedCounter[frame_name]++; + + // Adding the partition state machines to the graph + for (auto& it: state_machine.partitionsToSMG) { + string partition = it.first; + + // The Leader Partition already has its state machine + if (partition == state_machine.leader_partition) + continue; + + std::set<string>* smgs = it.second; + + // Collect all the state machine graphs that smgs depend on + GatherPartitionStateMachines(state_machine, smgs); + + for (auto& it : *smgs) + AddPartitionStateMachine(state_machine, main_graphDef, it, partition); + } + + // Deallocate space + for (auto& it : state_machine.partitionsToSMG) + delete it.second; + state_machine.partitionsToSMG.clear(); + + for (auto& it: state_machine.state_machine_graphs) + DeleteStateMachineGraph(state_machine, it.first); + state_machine.state_machine_graphs.clear(); + } + } else + ready_nodes.push_back(out); + } + } + } + + // Deallocate space + for (auto& it : funcInfo.funcCalls) + delete it.second; + +/****************************************************************************/ + printf("\n\nSummarize Main Graph\n %s\n", SummarizeGraphDef(main_graphDef).c_str()); + // Write an event, so that we can visualize this optimized graph in TensorBoard + EventsWriter writer("Full_Partitioned"); + Event event; + event.set_wall_time(1234); + event.set_step(34); + + const size_t proto_size = main_graphDef.ByteSizeLong(); + void* buf = port::Malloc(proto_size); + if (buf == nullptr) { + return errors::ResourceExhausted( + "Failed to allocate memory to serialize message of type '" , + main_graphDef.GetTypeName(), "' and size ", proto_size); + } + main_graphDef.SerializeToArray(buf, proto_size); + const void* bf = buf; + event.set_graph_def(bf, proto_size); + port::Free(buf); + writer.WriteEvent(event); +/****************************************************************************/ + + return Status::OK(); +} + + + +/**************************************************************************************************/ + + Status Partition(const PartitionOptions& opts, Graph* g, std::unordered_map<string, GraphDef>* partitions) { Status status; partitions->clear(); GraphInfo g_info; + std::unique_ptr<Graph> new_g(new Graph(OpRegistry::Global())); + if (!opts.control_flow_added) { // Add the "code" for
distributed execution of control flow. Code is // added only for the frames that are placed on multiple devices. The // new graph is an equivalent transformation of the original graph and // has the property that it can be subsequently partitioned arbitrarily // (down to the level of individual device) for distributed execution. - status = AddControlFlow(opts, g, &g_info); + GraphDef main_graphDef; + g->ToGraphDef(&main_graphDef); + printf("\n\nSummarize Main Graph:\n %s\n\n", SummarizeGraphDef(main_graphDef).c_str()); + + status = AddControlFlow(opts, g, &g_info); if (!status.ok()) return status; + + GraphDef gdef; + status = AddFunctionStateMachines(opts, g, gdef, &g_info); + if (status.ok()) { + // Convert the GraphDef back to a Graph so it can be partitioned + GraphConstructorOptions gopts; + gopts.allow_internal_ops = true; + TF_RETURN_IF_ERROR( + ConvertGraphDefToGraph(gopts, gdef, new_g.get())); + g = new_g.get(); + + // The graph conversion sets the requested device names but not the assigned + // device names. However, since at this point the graph is placed, TF expects + // an assigned device name for every node. Therefore we copy the requested + // device into the assigned device field. + for (Node* node : g->nodes()) { + node->set_assigned_device_name(node->requested_device()); + } + } else return status; } // At this point, all the graph mutations have been done. Build memory @@ -994,7 +1662,19 @@ Status Partition(const PartitionOptions& opts, Graph* g, int32 num_input_edges = 0; for (const Edge* edge : dst->in_edges()) { if (edge->IsControlEdge()) { - if (IsMerge(edge->src()) && IsControlLoop(edge->src())) { + if ((IsMerge(edge->src()) && IsControlLoop(edge->src())) || + (IsCallSuccessor(edge->src()) && (!IsConstant(edge->dst()) || + edge->dst()->in_edges().size() > 1))) { + // Note: not all control edges are control flow edges. + // There are also control edges added in + // FunctionTransformation for ensuring that Constants will execute in the + // correct 'frame'. + // We made sure in AddFunctionStateMachines that: + // if a Constant in partition A has such an incoming edge from a CallSuccessor + // node, then that node will definitely belong to the same partition A, so we + // can safely add those edges to "inputs" as we do with common control edges. + // All the other edges whose src node is a CallSuccessor node are control flow edges. + // This is one of the control edges added for control flow. There // can be multiple such edges as the dest node may have multiple // remote inputs. We keep track of the number of such edges. @@ -1102,7 +1782,7 @@ Status Partition(const PartitionOptions& opts, Graph* g, NodeDef* real_recv = nullptr; NodeDef* recv = - AddRecv(opts, g_info, dst_graph, edge, &real_recv, &status); + AddRecv(opts, g_info, dst_graph, edge, &real_recv, &status); if (!status.ok()) return status; // Fix up the control flow edge.
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index acb8498142..6b41f952b9 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -50,6 +50,16 @@ bool IsExit(const NodeDef& node) { return op == "Exit" || op == "RefExit"; } +bool IsCall(const NodeDef& node) { + const auto& op = node.op(); + return op == "Call" || op == "RefCall"; +} + +bool IsReturn(const NodeDef& node) { + const auto& op = node.op(); + return op == "Return" || op == "RefReturn"; +} + bool IsIdentity(const NodeDef& node) { const auto& op = node.op(); return op == "Identity" || op == "RefIdentity"; diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index 0de954fcb4..6feab5bb3d 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -27,6 +27,8 @@ bool IsConstant(const NodeDef& node); bool IsDequeueOp(const NodeDef& node); bool IsEnter(const NodeDef& node); bool IsExit(const NodeDef& node); +bool IsCall(const NodeDef& node); +bool IsReturn(const NodeDef& node); bool IsIdentity(const NodeDef& node); bool IsMerge(const NodeDef& node); bool IsNextIteration(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index c4def6cf23..cb3ca9d988 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -126,6 +126,25 @@ tf_cc_test( ], ) +cc_library( + name = "function_transformation", + srcs = ["function_transformation.cc"], + hdrs = [ + "function_transformation.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":graph_optimizer", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:op_types", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/utils:functions", + ], +) + cc_library( name = "graph_rewriter", srcs = ["graph_rewriter.cc"], @@ -304,6 +323,7 @@ cc_library( ":arithmetic_optimizer", ":auto_parallel", ":constant_folding", + ":function_transformation", ":graph_optimizer", ":layout_optimizer", ":memory_optimizer", diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index faea843c69..7885facac7 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -251,9 +251,10 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const { if (op == "Const") { return false; } - // Skip constrol flow nodes, they can't be folded + // Skip control flow nodes, they can't be folded if (op == "Enter" || op == "RefEnter" || op == "Exit" || op == "RefExit" || - op == "NextIteration" || op == "RefNextIteration") { + op == "NextIteration" || op == "RefNextIteration" || + op == "Call" || op == "RefCall" || op == "Return" || op == "RefReturn") { return false; } if (op.find("Placeholder") == 0) { @@ -283,7 +284,6 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const { if (op_def->output_arg_size() == 0) { return false; } - // No need to (and don't) fold nodes that have no outgoing edges except // whitelisted nodes. 
Such nodes could be introduced by an earlier constant // folding pass and are preserved in case users want to fetch their values; diff --git a/tensorflow/core/grappler/optimizers/function_transformation.cc b/tensorflow/core/grappler/optimizers/function_transformation.cc new file mode 100644 index 0000000000..46dd00825a --- /dev/null +++ b/tensorflow/core/grappler/optimizers/function_transformation.cc @@ -0,0 +1,498 @@ +/* Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/function_transformation.h" +#include <set> +#include <string> +#include <unordered_map> +#include "tensorflow/core/util/event.pb.h" +#include "tensorflow/core/util/events_writer.h" + +#include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/op_types.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/grappler/utils/functions.h" + +namespace tensorflow { +namespace grappler { +namespace { + +typedef std::unordered_map<string, NodeDef*> ArgMergeMap; + +typedef struct { + ArgMergeMap argMergeMap; + gtl::ArraySlice<string> fetch; +} FuncInfo; + +// same as commit b691c0 (possibly) +class FunctionInliningContext { + public: + explicit FunctionInliningContext(const GrapplerItem& item) + : library_(&item.graph.library()), functions_(InliningCandidates(item)) {} + + const FunctionDefLibrary& Library() const { return *library_; } + + bool HasInlinedFunctions() const { return !functions_.empty(); } + + // Find inlining candidate by name. Return nullptr if not found. + const FunctionDef* FindInlinedFunction(const string& name) const { + auto it = functions_.find(name); + if (it != functions_.end()) { + return it->second; + } else { + return nullptr; + } + } + + private: + std::unordered_map<string, const FunctionDef*> InliningCandidates(const GrapplerItem& item) const { + std::unordered_map<string, const FunctionDef*> functions; + for (const FunctionDef& func : item.graph.library().function()) { + // Don't inline functions marked as noinline +// if (func.attr().count("_noinline") != 0) { +// continue; +// } + // Don't touch anything marked XLA to prevent XLA failures further down + // the road. + if (func.attr().count("_XlaCompile") > 0 && + func.attr().at("_XlaCompile").b()) { + continue; + } + // Can't create IdentityN nodes with no input or output: skip these + // functions for now.
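+      // Note that recursive functions are deliberately kept as inlining
+      // candidates here: a recursive call site is not inlined again but is
+      // handled by CreateCycle() further below.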
+ if (func.signature().input_arg_size() == 0 || + func.signature().output_arg_size() == 0) { + continue; + } + functions[func.signature().name()] = &func; + } + return functions; + } + + const FunctionDefLibrary* library_; + std::unordered_map<string, const FunctionDef*> functions_; + + TF_DISALLOW_COPY_AND_ASSIGN(FunctionInliningContext); +}; + +// Copy the input/output argument type to `type`. Return an error if the argument +// type is not explicitly defined and not specified in the function attributes. +Status CopyArgType(const NodeDef& func_node, + const std::unordered_map<string, AttrValue>& func_attr, + const string& arg_kind, const OpDef::ArgDef& arg, + DataType* type) { + if (arg.type() != DT_INVALID) { + *type = arg.type(); + } else { + auto it = func_attr.find(arg.type_attr()); + if (it == func_attr.end() || it->second.type() == DT_INVALID) { + return errors::InvalidArgument( + "Invalid ", arg_kind, " argument ", arg.name(), " for function ", + func_node.op(), " instantiated by ", func_node.name()); + } + *type = it->second.type(); + } + return Status::OK(); +} + +// Copy the input/output argument type to `type_list`. Return an error if the argument +// type is not explicitly defined and not specified in the function attributes. +Status CopyArgType(const NodeDef& func_node, + const std::unordered_map<string, AttrValue>& func_attr, + const string& arg_kind, const OpDef::ArgDef& arg, + AttrValue::ListValue* type_list) { + if (arg.type() != DT_INVALID) { + type_list->add_type(arg.type()); + } else { + auto it = func_attr.find(arg.type_attr()); + if (it == func_attr.end() || it->second.type() == DT_INVALID) { + return errors::InvalidArgument( + "Invalid ", arg_kind, " argument ", arg.name(), " for function ", + func_node.op(), " instantiated by ", func_node.name()); + } + type_list->add_type(it->second.type()); + } + return Status::OK(); +} + +// Maps a tensor name that refers to a function-call output to the name of the +// corresponding Return node, e.g. "Fac:1" -> "Fac/Ret1" and "Fac" -> "Fac/Ret0". +string ParseString(string input) { + size_t pos = 0; + std::string res = ""; + std::string delimiter = ":"; + + if ((pos = input.find(delimiter)) != std::string::npos) { + res = res + input.substr(0, pos); + input.erase(0, pos + delimiter.length()); + res = res + "/Ret" + input; + } + else { + res = input + "/Ret0"; + } + return res; +} + +Status GatherOutputs(const GrapplerItem& item, const FunctionInliningContext& ctx, + std::set<string> &foutputs) { + for (const NodeDef& node : item.graph.node()) { + const FunctionDef* func = ctx.FindInlinedFunction(node.op()); + if (func != nullptr) { // If it's a function calling node + for (int i = 0; i < func->signature().output_arg_size(); ++i) { + // const OpDef::ArgDef &arg = func->signature().output_arg(i); + foutputs.emplace(node.name()); // Fac + foutputs.emplace(strings::StrCat(node.name(), ":", i)); // Fac:i + //foutputs.emplace(strings::StrCat(node.name(), ":", arg.name(), ":", i)); // Fac:outarg:i + } + } + } + return Status::OK(); +} + + +Status CreateCycle(NodeDef& func_node, const FunctionDef& func, GraphDef* optimized_graph, + std::unordered_map<string, FuncInfo> &functions_in, int call_id, string device) { + const std::unordered_map<string, AttrValue> func_attr(func_node.attr().begin(), func_node.attr().end()); + + DataType type; + ArgMergeMap& argmerge_map = functions_in[func_node.op()].argMergeMap; + + NodeDef *call; + for (int i = 0; i < func.signature().input_arg_size(); ++i) { + const OpDef::ArgDef &arg = func.signature().input_arg(i); + + // Create and add to the graph a Call node for every input arg + call = optimized_graph->add_node(); + call->set_name(strings::StrCat(func_node.name(), "/", "Call_", i)); + call->set_op("Call"); + call->set_device(device); + call->add_input(func_node.input(i)); + 
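// Every Call records the frame it enters ("frame_name"), the unique id of + // this call site ("call_id") and the index of the argument it carries + // ("arg_id"), so that it can later be matched with the Return nodes of + // the same invocation. +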
TF_RETURN_IF_ERROR(CopyArgType(func_node, func_attr, "input", arg, &type)); + (*call->mutable_attr())["T"].set_type(type); + (*call->mutable_attr())["frame_name"].set_s(func_node.op()); + (*call->mutable_attr())["call_id"].set_i(call_id); + (*call->mutable_attr())["arg_id"].set_i(i); + (*call->mutable_attr())["is_constant"].set_b(false); + + NodeDef* merge = argmerge_map[arg.name()]; + merge->add_input(call->name()); + } + + for (int i = 0; i < func.signature().output_arg_size(); ++i) { + const OpDef::ArgDef &arg = func.signature().output_arg(i); + + NodeDef *ret = optimized_graph->add_node(); + ret->set_name(strings::StrCat(func_node.name(), "/", "Ret", i)); + ret->set_op("Return"); + ret->set_device(device); + // Counting on the fact that the op name is the same as the name originally given to the function + ret->add_input(strings::StrCat(func_node.op(), "/", functions_in[func_node.op()].fetch[i])); + TF_RETURN_IF_ERROR(CopyArgType(func_node, func_attr, "output", arg, &type)); + (*ret->mutable_attr())["T"].set_type(type); + (*ret->mutable_attr())["frame_name"].set_s(func_node.op()); + (*ret->mutable_attr())["call_id"].set_i(call_id); + (*ret->mutable_attr())["arg_id"].set_i(i); + + // Add a control input from the Call to the Returns + *ret->add_input() = AsControlDependency(call->name()); + } + return Status::OK(); +} + + +Status InlineFunction(const NodeDef& func_node, const FunctionDef& func, + const FunctionInliningContext& ctx, + GraphDef* optimized_graph, + std::unordered_map<string, FuncInfo> &functions_in, + int& frame_name, string device) { + + int cpframe_name = frame_name; + + const std::unordered_map<string, AttrValue> func_attr(func_node.attr().begin(), func_node.attr().end()); + std::unique_ptr<GrapplerItem> item = GrapplerItemFromFunctionDef(func, func_attr, ctx.Library()); + + if (!item) { + return errors::InvalidArgument( + "Failed to inline function ", func_node.op(), + " instantiated by ", func_node.name()); + } + + std::set<string> foutputs; + GatherOutputs(*item, ctx, foutputs); + + DataType type; + std::unordered_map<string, int> input_nodes; + functions_in[func_node.op()].fetch = item->fetch; + ArgMergeMap& argmerge_map = functions_in[func_node.op()].argMergeMap; + + NodeDef* call; + for (int i = 0; i < func.signature().input_arg_size(); ++i) { + const OpDef::ArgDef& arg = func.signature().input_arg(i); + + input_nodes[arg.name()] = i; + + // Create and add to the graph a Call node for every input arg + call = optimized_graph->add_node(); + call->set_name(strings::StrCat(func_node.name(), "/", "Call_", i)); + call->set_op("Call"); + call->set_device(device); + call->add_input(func_node.input(i)); + TF_RETURN_IF_ERROR(CopyArgType(func_node, func_attr, "input", arg, &type)); + (*call->mutable_attr())["T"].set_type(type); + (*call->mutable_attr())["frame_name"].set_s(func_node.op()); + (*call->mutable_attr())["call_id"].set_i(frame_name); + (*call->mutable_attr())["arg_id"].set_i(i); + (*call->mutable_attr())["is_constant"].set_b(false); + + // Create and add a temporary merge node (IdentityN) for every input arg + NodeDef* merge = optimized_graph->add_node(); + merge->set_name(strings::StrCat(func_node.name(), "/", "Merge_", i)); + merge->set_op("IdentityN"); + merge->set_device(device); + merge->add_input(call->name()); + + argmerge_map.emplace(arg.name(), merge); + } + + for (NodeDef& func_body_node : *item->graph.mutable_node()) { + // If the func body node is func's input argument + if (input_nodes.find(func_body_node.name()) != input_nodes.end()) { + CHECK_EQ(0, func_body_node.input_size()); + // Turn input placeholders into identity nodes + if
(IsPlaceholder(func_body_node)) { + func_body_node.set_op("Identity"); + } + // Connect the merge with the input arg + func_body_node.add_input(argmerge_map[func_body_node.name()]->name()); + } else { // Else if not an input_arg_node + // Update the input names if any. + for (string& input : *func_body_node.mutable_input()) { + + // If it takes input from a function + if (foutputs.find(input) != foutputs.end()) { + input = ParseString(input); + } + input = AddPrefixToNodeName(input, /*prefix=*/func_node.name()); + } + // If the node has no input, hook it up to the Merge nodes to ensure + // it runs in the same frame as the other nodes of the function body. + if (func_body_node.input_size() == 0) { + for (auto it = argmerge_map.begin(); it != argmerge_map.end(); ++it) { + *func_body_node.add_input() = AsControlDependency(it->second->name()); + } + } + } + + // Add the node name as a prefix to avoid collisions after inlining + func_body_node.set_name(strings::StrCat(func_node.name(), "/", func_body_node.name())); + + // Make sure the node is placed + if (func_body_node.device().empty()) { + func_body_node.set_device(device); + } + + // Check if a body node is itself a function + const FunctionDef* func_body_node_func = ctx.FindInlinedFunction(func_body_node.op()); + + // Node is yet another function + if (func_body_node_func != nullptr) { + + // Check if that function has already been inlined + auto it = functions_in.find(func_body_node.op()); + + // Not already in => Inline it + if (it == functions_in.end()) { + FuncInfo func_info; + functions_in.emplace(func_body_node.op(), func_info); + InlineFunction(func_body_node, *func_body_node_func, ctx, optimized_graph, functions_in, ++frame_name, device); + functions_in.erase(func_body_node.op()); + } else { + // Already in -> insert Call/Return ops and create a cycle + // (recursion or mutually recursive functions) + CreateCycle(func_body_node, *func_body_node_func, optimized_graph, functions_in, ++frame_name, device); + } + } else { + // Move the node to the main graph + optimized_graph->add_node()->Swap(&func_body_node); + } + } + + for (int i = 0; i < func.signature().output_arg_size(); ++i) { + const OpDef::ArgDef &arg = func.signature().output_arg(i); + + NodeDef *ret = optimized_graph->add_node(); + ret->set_name(strings::StrCat(func_node.name(), "/", "Ret", i)); + ret->set_op("Return"); + ret->set_device(device); + // If it takes input from a function + string input = item->fetch[i]; + if (foutputs.find(input) != foutputs.end()) { + input = ParseString(input); + } + + ret->add_input(strings::StrCat(func_node.name(), "/", input)); + TF_RETURN_IF_ERROR(CopyArgType(func_node, func_attr, "output", arg, &type)); + (*ret->mutable_attr())["T"].set_type(type); + (*ret->mutable_attr())["frame_name"].set_s(func_node.op()); + (*ret->mutable_attr())["call_id"].set_i(cpframe_name); + (*ret->mutable_attr())["arg_id"].set_i(i); + + // Add a control input from the Call to the Returns + *ret->add_input() = AsControlDependency(call->name()); + } + + int j = 0; + for (auto it = argmerge_map.begin(); it != argmerge_map.end(); ++it, ++j) { + DataType type; + NodeDef *merge = it->second; + int size = merge->input_size(); + + TF_RETURN_IF_ERROR(CopyArgType(func_node, func_attr, + "input", func.signature().input_arg(j), &type)); + + if (size <= 1) { + merge->set_op("Identity"); + merge->set_device(device); + (*merge->mutable_attr())["T"].set_type(type); + } else { + merge->set_op("Merge"); + 
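// More than one incoming Call means the argument has several call sites + // (e.g. the original call plus a recursive one added by CreateCycle), so + // the temporary IdentityN is turned into a real Merge over all of them. +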
merge->set_device(func_node.device()); + (*merge->mutable_attr())["T"].set_type(type); + (*merge->mutable_attr())["N"].set_i(size); + } + } + + return Status::OK(); +} + +} // namespace + +Status FunctionTransformation::Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) { + FunctionInliningContext ctx(item); + + int frame_name = 0; + std::set<string> foutputs; + + GatherOutputs(item, ctx, foutputs); + + //std::cout << foutputs.size() << '\n'; + //for( const auto& str : foutputs ) std::cout << str << '\n'; + + // Nothing to do here. + if (!ctx.HasInlinedFunctions()) { + *optimized_graph = item.graph; + return Status::OK(); + } + + std::unordered_map<string, FuncInfo> functions_in; + + // Copy the node, because we need to make changes to it + for (NodeDef node : item.graph.node()) { + for (string& input : *node.mutable_input()) { + // If it takes input from a function + if (foutputs.find(input) != foutputs.end()) { + input = ParseString(input); + } + } + + const FunctionDef* func = ctx.FindInlinedFunction(node.op()); + if (func != nullptr) { + FuncInfo func_info; + // All the special nodes of this function, and of its 'callee functions' too, + // will be colocated on the same device (important for the distributed case) + string device = node.device(); + functions_in.emplace(node.op(), func_info); + InlineFunction(node, *func, ctx, optimized_graph, functions_in, ++frame_name, device); + functions_in.erase(node.op()); // At this point functions_in will be empty + + // Check if the function node corresponded to some fetch_outputs + // before the transformation occurred + NodeDef *idN; + bool created = false; + const std::unordered_map<string, AttrValue> func_attr(node.attr().begin(), node.attr().end()); + + for (size_t i = 0; i < item.fetch.size(); ++i) { + const string &t = item.fetch[i]; + // Parse t into node_name and output_index.
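+        // e.g. (illustrative) the fetch "MyFunc:1" parses into ("MyFunc", 1)
+        // and is served after the transformation by the new node "MyFunc/Ret1".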
+ TensorId id(ParseTensorName(t)); + + if (node.name() == id.first) { + + if (!created) { + idN = optimized_graph->add_node(); + idN->set_op("IdentityN"); + idN->set_name(node.name()); + idN->set_device(device); + + AttrValue::ListValue* type_list = (*idN->mutable_attr())["T"].mutable_list(); + for (const OpDef::ArgDef& arg : func->signature().output_arg()) { + TF_RETURN_IF_ERROR(CopyArgType(node, func_attr, "output", arg, type_list)); + } + + created = true; + } + idN->add_input(strings::StrCat(node.name(), "/Ret", id.second)); + } + } + } else { + *optimized_graph->add_node() = node; + } + } + + *optimized_graph->mutable_versions() = item.graph.versions(); + *optimized_graph->mutable_library() = item.graph.library(); + + /****************************************************************************************************** + // Dumps the optimized graph in a not so readable form + // const GraphDef* tmp = optimized_graph; + // printf("Summarize Optimized Graph\n %s\n", SummarizeGraphDef(*tmp).c_str()); + + // Write an event, so that we can visualize this optimized graph in TensorBoard + EventsWriter writer("TRANSFORMATION"); + Event event; + event.set_wall_time(1234); + event.set_step(34); + + const size_t proto_size = optimized_graph->ByteSizeLong(); + void* buf = port::Malloc(proto_size); + if (buf == nullptr) { + return errors::ResourceExhausted( + "Failed to allocate memory to serialize message of type '" , + optimized_graph->GetTypeName(), "' and size ", proto_size); + } + optimized_graph->SerializeToArray(buf, proto_size); + const void* bf = buf; + event.set_graph_def(bf, proto_size); + writer.WriteEvent(event); + ******************************************************************************************************/ + + return Status::OK(); +} + +void FunctionTransformation::Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimized_graph, + double result) { + // Nothing to do for FunctionTransformation. +} + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/function_transformation.h b/tensorflow/core/grappler/optimizers/function_transformation.h new file mode 100644 index 0000000000..8ed60b3061 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/function_transformation.h @@ -0,0 +1,42 @@ +/* Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#ifndef TENSORFLOW_GRAPPLER_OPTIMIZERS_FUNCTION_TRANSFORMATION_H_ +#define TENSORFLOW_GRAPPLER_OPTIMIZERS_FUNCTION_TRANSFORMATION_H_ + +#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" +#include "tensorflow/core/grappler/grappler_item.h" + +namespace tensorflow { +namespace grappler { + + +// Replace function calling nodes with pairs of new 'Call/Return' operators +class FunctionTransformation : public GraphOptimizer { +public: + FunctionTransformation() {} + ~FunctionTransformation() override {} + + string name() const override { return "function_transformation"; } + + Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) override; + + void Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimized_graph, double result) override; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif  // TENSORFLOW_GRAPPLER_OPTIMIZERS_FUNCTION_TRANSFORMATION_H_ diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index 6718d2d739..cf0345ce8e 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/layout_optimizer.h" #include "tensorflow/core/grappler/optimizers/memory_optimizer.h" #include "tensorflow/core/grappler/optimizers/model_pruner.h" +#include "tensorflow/core/grappler/optimizers/function_transformation.h" #include "tensorflow/core/grappler/utils/topological_sort.h" #include "tensorflow/core/lib/core/status.h" @@ -36,6 +37,9 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::NewOptimizer( if (optimizer == "pruning") { graph_optimizer.reset(new ModelPruner()); } + if (optimizer == "function_transformation") { + graph_optimizer.reset(new FunctionTransformation()); + } if (optimizer == "constfold") { graph_optimizer.reset(new ConstantFolding(cpu_device_)); } @@ -62,6 +66,10 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, if (!cfg_.disable_model_pruning()) { optimizers.push_back(std::unique_ptr<GraphOptimizer>(new ModelPruner())); } + if (cfg_.function_transformation() != RewriterConfig::OFF) { + optimizers.push_back( + std::unique_ptr<GraphOptimizer>(new FunctionTransformation())); + } if (cfg_.constant_folding() != RewriterConfig::OFF) { optimizers.push_back( std::unique_ptr<GraphOptimizer>(new ConstantFolding(cpu_device_))); @@ -92,6 +100,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, } } else { std::set<string> available_optimizers = {"pruning", "constfold", + "function_transformation", "layout", "memory", "autoparallel", "arithmetic"}; for (const auto& optimizer : cfg_.optimizers()) { @@ -137,6 +146,7 @@ void MetaOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item, bool MetaOptimizerEnabled(const RewriterConfig& cfg) { return !cfg.disable_model_pruning() || cfg.optimize_tensor_layout() || + cfg.function_transformation() != RewriterConfig::OFF || cfg.constant_folding() != RewriterConfig::OFF || cfg.arithmetic_optimization() != RewriterConfig::OFF || cfg.auto_parallel().enable() || cfg.memory_optimization() > 1 || diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD index bb161bf9a4..e5b916d170 100644 --- a/tensorflow/core/grappler/utils/BUILD +++ b/tensorflow/core/grappler/utils/BUILD @@ -97,3 +97,20 @@ tf_cc_test( "//tensorflow/core:test_main", ], ) + 
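+# Converts a FunctionDef into a GrapplerItem so that function bodies can be +# manipulated like ordinary graphs (see functions.h).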
+cc_library( + name = "functions", + srcs = [ + "functions.cc", + ], + hdrs = ["functions.h"], + visibility = ["//visibility:public"], + deps = [ +# "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:utils", + ], +) diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc new file mode 100644 index 0000000000..4f286ce1c8 --- /dev/null +++ b/tensorflow/core/grappler/utils/functions.cc @@ -0,0 +1,153 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/grappler/utils/functions.h" + +#include <unordered_map> + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/grappler/utils.h" + +namespace tensorflow { +namespace grappler { + +std::unique_ptr<GrapplerItem> GrapplerItemFromFunctionDef( + const FunctionDef& func, + const std::unordered_map<string, AttrValue>& func_attr, + const FunctionDefLibrary& library) { + if (func.signature().name().empty()) { + LOG(ERROR) << "Function name must be specified."; + return nullptr; + } + std::unique_ptr<GrapplerItem> new_item(new GrapplerItem()); + new_item->id = func.signature().name(); + + std::unordered_map<string, string> port_map; + + // Add the function inputs as placeholders + for (const auto& inp : func.signature().input_arg()) { + NodeDef* ph = new_item->graph.add_node(); + ph->set_name(inp.name()); + ph->set_op("Placeholder"); + if (inp.type() != DT_INVALID) { + (*ph->mutable_attr())["T"].set_type(inp.type()); + } else { + auto it = func_attr.find(inp.type_attr()); + if (it == func_attr.end()) { + LOG(ERROR) << "Unknown type attribute " << inp.type_attr() + << " for function input " << inp.name(); + return nullptr; + } else { + (*ph->mutable_attr())["T"] = it->second; + } + } + port_map[inp.name()] = inp.name(); + } + + // Add the function body to the graph. + FunctionLibraryDefinition func_def(OpRegistry::Global(), library); + + for (const NodeDef& node : func.node_def()) { + NodeDef* new_node = new_item->graph.add_node(); + *new_node = node; + // Replace the placeholder attribute values with the specified value. + for (auto& attr : *new_node->mutable_attr()) { + const string& ph_name = attr.second.placeholder(); + auto it = func_attr.find(ph_name); + if (it != func_attr.end()) { + attr.second = it->second; + } + } + + // Functions use a custom format to encode connectivity. Map these custom + // strings to regular ones.
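+    // e.g. (illustrative) for a node "n" whose op declares output args "y"
+    // and "z" with one tensor each, the function-body reference "n:z:0" is
+    // mapped to the plain graph form "n:1".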
+ const OpRegistrationData* registration; + Status status = func_def.LookUp(node.op(), &registration); + if (!status.ok()) { + LOG(ERROR) << "Op " << node.op() << " not registered: " << status; + return nullptr; + } + + tensorflow::NameRangeMap inputs; + tensorflow::NameRangeMap outputs; + status = tensorflow::NameRangesForNode(node, registration->op_def, &inputs, + &outputs); + if (!status.ok()) { + LOG(ERROR) << "Op " << node.op() << " invalid: " << status; + return nullptr; + } + for (const auto& name_range : outputs) { + string port_prefix = + strings::StrCat(node.name(), ":", name_range.first, ":"); + int index_start = name_range.second.first; + int index_end = name_range.second.second; + for (int i = index_start; i < index_end; ++i) { + string port_id = strings::StrCat(port_prefix, i - index_start); + string port_name = strings::StrCat(node.name(), ":", i); + port_map[port_id] = port_name; + } + } + } + + for (auto& node : *new_item->graph.mutable_node()) { + // Rewrite the inputs to use the normal naming convention. + for (int i = 0; i < node.input_size(); ++i) { + const string& input = node.input(i); + if (IsControlInput(input)) { + // No need to remap control dependencies. + continue; + } else { + auto it = port_map.find(input); + if (it == port_map.end()) { + LOG(ERROR) << "Unknown input: " << input; + return nullptr; + } + node.set_input(i, it->second); + } + } + } + + // Add the function outputs to the list of fetch nodes, taking into account + // the output mapping if any. + for (const auto& out : func.signature().output_arg()) { + auto it = func.ret().find(out.name()); + if (it != func.ret().end()) { + auto it2 = port_map.find(it->second); + if (it2 == port_map.end()) { + LOG(ERROR) << "Unknown output mapping: " << it->first << " to " + << it->second; + return nullptr; + } else { + new_item->fetch.emplace_back(it2->second); + } + } else { + new_item->fetch.emplace_back(out.name()); + } + } + // Add the function inputs to the list of feeds. + for (const auto& inp : func.signature().input_arg()) { + new_item->feed.emplace_back(inp.name(), Tensor()); + } + + return new_item; +} + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/utils/functions.h b/tensorflow/core/grappler/utils/functions.h new file mode 100644 index 0000000000..8f9b7d848a --- /dev/null +++ b/tensorflow/core/grappler/utils/functions.h @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_GRAPPLER_UTILS_FUNCTIONS_H_ +#define TENSORFLOW_GRAPPLER_UTILS_FUNCTIONS_H_ + +#include <memory> +#include <unordered_map> +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/grappler/grappler_item.h" + +namespace tensorflow { + +namespace grappler { + +// Factory method for creating a GrapplerItem from a FunctionDef.
+// Returns nullptr if the given function def cannot be converted. +std::unique_ptr<GrapplerItem> GrapplerItemFromFunctionDef( + const FunctionDef& func, + const std::unordered_map<string, AttrValue>& func_attr, + const FunctionDefLibrary& library); + +} // end namespace grappler +} // end namespace tensorflow + +#endif  // TENSORFLOW_GRAPPLER_UTILS_FUNCTIONS_H_ diff --git a/tensorflow/core/grappler/utils/topological_sort.cc b/tensorflow/core/grappler/utils/topological_sort.cc index 77d4702d21..62344aca75 100644 --- a/tensorflow/core/grappler/utils/topological_sort.cc +++ b/tensorflow/core/grappler/utils/topological_sort.cc @@ -17,6 +17,7 @@ limitations under the License. #include <deque> #include <unordered_map> #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/utils.h" @@ -32,24 +33,63 @@ void TopologicalSort(GraphDef* graph) { int front = 0; int back = 0; std::unordered_map<NodeDef*, int> ready_inputs; + std::unordered_map<NodeDef*, std::set<int>> returning_nodes; for (int i = 0; i < graph->node_size(); i++) { auto node = graph->mutable_node(i); if (node->input_size() == 0) { ready_nodes.push_back(node); back++; } + bool recursion_merge = false; + if (IsMerge(*node)) { ready_inputs[node] = 0; for (const auto& input : node->input()) { if (IsNextIteration(*output_map.GetNode(input))) { ready_inputs[node]++; } + else if (IsCall(*output_map.GetNode(input))) { + ready_inputs[node]++; + recursion_merge = true; + } + } + if (recursion_merge) { + ready_inputs[node]--; + recursion_merge = false; + } + + } else if (IsReturn(*node)) { + // Nodes that send their output to "Return" nodes are + // function returning nodes, and in the case of recursive functions + // those nodes are part of graph cycles. + for (const auto& input : node->input()) { + NodeDef *prevNode = output_map.GetNode(input); + // In order to detect the recursion cycles we depend on + // the fact that a recursive function's returning node + // will be sending outputs to at least 2 "Return" nodes + // with different "call_id" attributes (same "call_id" + // attrs would mean that they belong in the same function call + // but they correspond to different function outputs) + if (!StringPiece(input).starts_with("^")) { + int call_id; + GetNodeAttr(AttrSlice(*node), "call_id", &call_id); + returning_nodes[prevNode].emplace(call_id); + } } + ready_inputs[node] = 0; + } else { ready_inputs[node] = 0; } } + for (const auto& retnode : returning_nodes) { + if (retnode.second.size() > 1) { + // Detected Cycle + ready_inputs[retnode.first]++; + } + } + while (front != back) { auto ready_node = ready_nodes[front]; for (const auto& fanout_pair : output_map.GetOutputs(ready_node->name())) { diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index bdc6faefbc..706f5d7b82 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -1437,6 +1437,16 @@ tf_cc_test( ], ) + +tf_kernel_library( + name = "function_control_ops", + prefix = "function_control_ops", + deps = [ + "//tensorflow/core:function_control_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], +) + cc_library( name = "data_flow", deps = [ diff --git a/tensorflow/core/kernels/function_control_ops.cc b/tensorflow/core/kernels/function_control_ops.cc new file mode 100644 index 0000000000..a22c079102 --- /dev/null +++ b/tensorflow/core/kernels/function_control_ops.cc @@ -0,0 +1,191 @@ +/* Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in
compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/kernels/function_control_ops.h" + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { + +void CallOpe::Compute(OpKernelContext* context) { + if (IsRefType(context->input_dtype(0))) { + context->forward_ref_input_to_ref_output(0, 0); + } else { + context->set_output(0, context->input(0)); + } +} + +REGISTER_KERNEL_BUILDER(Name("Call").Device(DEVICE_CPU), CallOpe); +REGISTER_KERNEL_BUILDER(Name("RefCall").Device(DEVICE_CPU), CallOpe); + +#define REGISTER_GPU_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Call").Device(DEVICE_GPU).TypeConstraint<type>("T"), CallOpe) +#define REGISTER_GPU_REF_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("RefCall").Device(DEVICE_GPU).TypeConstraint<type>("T"), CallOpe) + +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL); +REGISTER_GPU_KERNEL(bool); +REGISTER_GPU_REF_KERNEL(bool); + +#undef REGISTER_GPU_KERNEL +#undef REGISTER_GPU_REF_KERNEL + +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_SYCL_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Call").Device(DEVICE_SYCL).TypeConstraint<type>("T"), CallOpe) +REGISTER_SYCL_KERNEL(bool); +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); + +#define REGISTER_SYCL_REF_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("RefCall").Device(DEVICE_SYCL).TypeConstraint<type>("T"), CallOpe) +REGISTER_SYCL_REF_KERNEL(bool); +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL); + +#undef REGISTER_SYCL_KERNEL +#undef REGISTER_SYCL_REF_KERNEL +#define REGISTER_SYCL_HOST_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("Call") \ + .Device(DEVICE_SYCL) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint<type>("T"), \ + CallOpe) + +#define REGISTER_SYCL_HOST_REF_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("RefCall") \ + .Device(DEVICE_SYCL) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint<type>("T"), \ + CallOpe) + +REGISTER_SYCL_HOST_KERNEL(int32); +REGISTER_SYCL_HOST_REF_KERNEL(int32); +REGISTER_SYCL_HOST_KERNEL(string); +REGISTER_SYCL_HOST_REF_KERNEL(string); +REGISTER_SYCL_HOST_KERNEL(ResourceHandle); + +#undef REGISTER_SYCL_HOST_KERNEL +#undef REGISTER_SYCL_HOST_REF_KERNEL +#endif // TENSORFLOW_USE_SYCL + +#define REGISTER_GPU_HOST_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("Call") \ + .Device(DEVICE_GPU) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint<type>("T"), \ + CallOpe) + +#define REGISTER_GPU_HOST_REF_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("RefCall") \ + .Device(DEVICE_GPU) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint<type>("T"), \ + CallOpe) + +REGISTER_GPU_HOST_KERNEL(int32); +REGISTER_GPU_HOST_REF_KERNEL(int32); +REGISTER_GPU_HOST_KERNEL(string); +REGISTER_GPU_HOST_REF_KERNEL(string); +REGISTER_GPU_HOST_KERNEL(ResourceHandle); + +#undef REGISTER_GPU_HOST_KERNEL +#undef REGISTER_GPU_HOST_REF_KERNEL + +void ReturnOp::Compute(OpKernelContext* context) { + if (IsRefType(context->input_dtype(0))) { + context->forward_ref_input_to_ref_output(0, 0); + } else { + context->set_output(0, context->input(0)); + } +} + +REGISTER_KERNEL_BUILDER(Name("Return").Device(DEVICE_CPU), ReturnOp); +REGISTER_KERNEL_BUILDER(Name("RefReturn").Device(DEVICE_CPU), ReturnOp); + +#define REGISTER_GPU_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Return").Device(DEVICE_GPU).TypeConstraint<type>("T"), ReturnOp); +#define REGISTER_GPU_REF_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("RefReturn").Device(DEVICE_GPU).TypeConstraint<type>("T"), ReturnOp); + +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_KERNEL); +REGISTER_GPU_KERNEL(bool); +REGISTER_GPU_REF_KERNEL(bool); + +#undef REGISTER_GPU_KERNEL +#undef REGISTER_GPU_REF_KERNEL + +#ifdef TENSORFLOW_USE_SYCL + #define REGISTER_SYCL_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Return").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ReturnOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("RefReturn").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ReturnOp); +REGISTER_SYCL_KERNEL(bool); +TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); + +#undef REGISTER_SYCL_KERNEL +#undef REGISTER_SYCL_REF_KERNEL + +#define REGISTER_SYCL_HOST_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("Return") \ + .Device(DEVICE_SYCL) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint<type>("T"), \ + ReturnOp); \ + REGISTER_KERNEL_BUILDER(Name("RefReturn") \ + .Device(DEVICE_SYCL) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint<type>("T"), \ + ReturnOp) + +REGISTER_SYCL_HOST_KERNEL(int32); +REGISTER_SYCL_HOST_KERNEL(string); +#undef REGISTER_SYCL_HOST_KERNEL +#endif // TENSORFLOW_USE_SYCL + +#define REGISTER_GPU_HOST_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("Return") \ + .Device(DEVICE_GPU) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint<type>("T"), \ + ReturnOp); \ + REGISTER_KERNEL_BUILDER(Name("RefReturn") \ + .Device(DEVICE_GPU) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint<type>("T"), \ + ReturnOp) + +REGISTER_GPU_HOST_KERNEL(int32); +REGISTER_GPU_HOST_KERNEL(string); + +#undef REGISTER_GPU_HOST_KERNEL + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/function_control_ops.h b/tensorflow/core/kernels/function_control_ops.h new file mode 100644 index 0000000000..b03d3eae9a --- /dev/null +++ b/tensorflow/core/kernels/function_control_ops.h @@ -0,0 +1,47 @@ +/* Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_KERNELS_FUNCTION_CONTROL_OPS_H_ +#define TENSORFLOW_KERNELS_FUNCTION_CONTROL_OPS_H_ + +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +// A call op has one input and one output.
diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto
index 8a8dd3c7d5..ce8fb89611 100644
--- a/tensorflow/core/protobuf/rewriter_config.proto
+++ b/tensorflow/core/protobuf/rewriter_config.proto
@@ -31,6 +31,8 @@ message RewriterConfig {
   Toggle constant_folding = 3;
   // Arithmetic optimizations (default is ON)
   Toggle arithmetic_optimization = 7;
+  // Function transformation (default is ON).
+  Toggle function_transformation = 10;
   // If true, don't remove unnecessary ops from the graph
   bool disable_model_pruning = 2;
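The new toggle can be set from Python through the session's graph options. A minimal sketch, assuming this patch is applied (stock TensorFlow has no `function_transformation` field); shown here disabling the pass for one session:

    import tensorflow as tf
    from tensorflow.core.protobuf import rewriter_config_pb2

    # Turn the function-transformation rewrite off for this session only;
    # leaving the field at its default keeps the pass enabled.
    rewrite_options = rewriter_config_pb2.RewriterConfig(
        function_transformation=rewriter_config_pb2.RewriterConfig.OFF)
    config = tf.ConfigProto(
        graph_options=tf.GraphOptions(rewrite_options=rewrite_options))

    with tf.Session(config=config) as sess:
        pass  # build and run the graph as usual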
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index 7068e72009..9dabf8eda6 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -24,6 +24,7 @@
 import collections
 import hashlib
 
+from tensorflow.core.framework import op_def_pb2
 from tensorflow.core.framework import attr_value_pb2
 from tensorflow.core.framework import function_pb2
 from tensorflow.python import pywrap_tensorflow as c_api
@@ -187,6 +188,61 @@ def __call__(self, func):
                         **self._extra_kwargs)
 
 
+class Declare(object):
+  """Declares a TensorFlow function.
+
+  The object represents a TensorFlow function which will be defined
+  later during graph construction.
+
+  For example,
+
+    # Declares a function Foo, which takes a tf.int32 named "n" and a
+    # tf.float32 named "x" as inputs and returns a tf.float32 named "z"
+    # as its output.
+    foo = Declare("Foo", [("n", tf.int32), ("x", tf.float32)],
+                  [("z", tf.float32)])
+
+    # Defines a function Bar that calls Foo.
+    @tf.Defun(tf.float32)
+    def Bar(x):
+      return foo(6, x)
+
+    # Defines Foo, with output named "z".
+    @tf.Defun(tf.int32, tf.float32, out_names=["z"])
+    def Foo(n, x):
+      ...  # Calculation.
+      return result
+  """
+
+  def __init__(self, func_name, inputs, outputs):
+    """Creates a `Declare` object.
+
+    Args:
+      func_name: The name of the function.
+      inputs: A list of (name, data type) pairs of function arguments.
+      outputs: A list of (name, data type) pairs of function return values.
+    """
+    self._sig = op_def_pb2.OpDef()
+    self._sig.name = func_name
+
+    def _to_argdef_list(args):
+      names = [n for n, t in args]
+      if len(names) != len(set(names)):
+        raise ValueError("Expected names to all be unique: %s" % str(names))
+      return [
+          op_def_pb2.OpDef.ArgDef(type=t.as_datatype_enum, name=n)
+          for n, t in args
+      ]
+
+    self._sig.input_arg.extend(_to_argdef_list(inputs))
+    self._sig.output_arg.extend(_to_argdef_list(outputs))
+
+  def __call__(self, *inputs, **kwargs):
+    inputs = [ops.convert_to_tensor(_) for _ in inputs]
+    return _call(self._sig, *inputs, **kwargs)[0]
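+
+# Usage sketch (comment only, nothing here executes): a forward declaration
+# also makes mutually recursive functions expressible, since `IsEven` below
+# must call `IsOdd` before `IsOdd` has been defined. The names and the
+# parity encoding (1 for true, 0 for false) are illustrative:
+#
+#   is_odd = Declare("IsOdd", [("n", tf.int32)], [("ret", tf.int32)])
+#
+#   @tf.Defun(tf.int32, func_name="IsEven", out_names=["ret"])
+#   def IsEven(n):
+#     return tf.cond(tf.equal(n, 0),
+#                    lambda: tf.constant(1),
+#                    lambda: is_odd(n - 1))
+#
+#   @tf.Defun(tf.int32, func_name="IsOdd", out_names=["ret"])
+#   def IsOdd(n):
+#     return tf.cond(tf.equal(n, 0),
+#                    lambda: tf.constant(0),
+#                    lambda: IsEven(n - 1))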
+
+
 class _DefinedFunction(object):
   """_DefinedFunction encapsulates a function definition and its properties.