TF对应函数使用

TF中在实际应用中最常见的就是求最大值最小值的索引,但是忘过好多次,还是现查的资料,写篇文章记录一下:

TF中求最大值最小值索引的函数为tf.argamax() 和 tf.argmin()

函数使用官方文档(顺便安利一下Kite,太好用了)

tensorflow.argmax
function
SIGNATURE
input,
axis=None,
name=None,
dimension=None,
output_type="<dtype: 'int64'>"
RETURNSTensor | Variable | LabeledTensor | Variable | Base | SparseTensor
HOW OTHERS USED THIS
argmax(​input, axis​)
argmax(​input, axis​)
argmax(​input, axis, name=""predictions""​)
argmax(​input, axis​)
DOCUMENTATION
Returns the index with the largest value across axes of a tensor. (deprecated arguments)

Warning: SOME ARGUMENTS ARE DEPRECATED: `(dimension)`. They will be removed in a future version.
Instructions for updating:
Use the `axis` argument instead

Note that in case of ties the identity of the return value is not guaranteed.

Args:
  input: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`.
  axis: A `Tensor`. Must be one of the following types: `int32`, `int64`.
    int32 or int64, must be in the range `[-rank(input), rank(input))`.
    Describes which axis of the input Tensor to reduce across. For vectors,
    use axis = 0.
  output_type: An optional `tf.DType` from: `tf.int32, tf.int64`. Defaults to `tf.int64`.
  name: A name for the operation (optional).

Returns:
  A `Tensor` of type `output_type`.

第二个参数中axis很容易混淆,axis参数的用法:axis参数的值可以是0,也可以是1,当axis的值是0时,一般来说过用来比较一维数据,也就是按行来比较(一维),当axis的值是1时,通常是多维数组的比较,,就是按列来比较,返回的是最大值的索引。

语法用例

# 最大最小值的索引位置
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'


import tensorflow as tf
os.system('cls')

# 打印函数
def log(prefix="",val=""):
    print(prefix,val,'\n')

# 初始化随机数组
a = tf.random.uniform((3,10),minval=0,maxval=10,dtype=tf.int32)
log('a',a)

# 取最大索引位置的数据,通常用于取得模型预测结果
"""
axis参数的用法:axis参数的值可以是0,也可以是1,当axis的值是0时,一般来说过用来比较一维数据,
也就是按行来比较(一维),当axis的值是1时,通常是多维数组的比较,,就是按列来比较,
返回的是最大值的索引。
"""
b = tf.argmax(a,axis=1)
log('a数组axis最大值为1时最大值索引的位置:',b)

# 取得最小值的索引
b = tf.argmin(a,axis=1)
log('a数组axis最大值为1时最小值索引值位置:',b)

知识点视频