完全同意mrry
的回答。
实际上,我将针对该问题发布另一种解决方案。
您可以使用tf.dynamic_partition()
代替tf.gather()
消除警告。
示例代码如下:
# Create the cells for the RNN network
lstm = tf.nn.rnn_cell.BasicLSTMCell(128)
# Get the output and state from dynamic rnn
output, state = tf.nn.dynamic_rnn(lstm, sequence, dtype=tf.float32, sequence_length = seqlen)
# Convert output to a tessor and reshape it
outputs = tf.reshape(tf.pack(output), [-1, lstm.output_size])
# Set partions to 2
num_partitions = 2
# The partitions argument is a tensor which is already fed to a placeholder.
# It is a 1-D tensor with the length of batch_size * max_sequence_length.
# In this partitions tensor, you need to set the last output idx for each seq to 1 and
# others remain 0, so that the result could be separated to two parts,
# one is the last outputs and the other one is the non-last outputs.
res_out = tf.dynamic_partition(outputs, partitions, num_partitions)
# prediction
preds = tf.matmul(res_out[1], weights) + bias
希望这可以对您有所帮助。
0
我最近实现了一个模型,运行时收到以下警告:
使用一些类似的参数设置(嵌入尺寸),模型突然变得很慢。