>>> import tensorflow as tf
>>> tf.enable_eager_execution()
>>> params = tf.constant([
... [[1,1],[2,2],[3,3],[4,4],[5,5]],
... [[1,1],[2,2],[3,3],[4,4],[5,5]],
... [[1,1],[2,2],[3,3],[4,4],[5,5]]])
>>> tf.gather(params, indices=[0,1], axis=0)
<tf.Tensor: id=47, shape=(2, 5, 2), dtype=int32, numpy=
array([[[1, 1],
        [2, 2],
        [3, 3],
        [4, 4],
        [5, 5]],

       [[1, 1],
        [2, 2],
        [3, 3],
        [4, 4],
        [5, 5]]], dtype=int32)>
>>> tf.gather(params, indices=[0,1], axis=1)
<tf.Tensor: id=51, shape=(3, 2, 2), dtype=int32, numpy=
array([[[1, 1],
        [2, 2]],

       [[1, 1],
        [2, 2]],

       [[1, 1],
        [2, 2]]], dtype=int32)>
>>> tf.gather(params, indices=[0,1], axis=2)  # axis 2 has size 2, so [0,1] keeps every element
<tf.Tensor: id=57, shape=(3, 5, 2), dtype=int32, numpy=
array([[[1, 1],
        [2, 2],
        [3, 3],
        [4, 4],
        [5, 5]],

       [[1, 1],
        [2, 2],
        [3, 3],
        [4, 4],
        [5, 5]],

       [[1, 1],
        [2, 2],
        [3, 3],
        [4, 4],
        [5, 5]]], dtype=int32)>
>>>