(四)TensorFlow数据读取

(四)TensorFlow数据读取

TensorFlow主要提供了三中读取数据的方式:

  • 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据。
  • 文件读取: 在TensorFlow图的起始, 让一个输入管线从文件中读取数据。
  • 预加载数据:在TensorFlow图中定义常量或变量来保存所有数据(仅适用于数据量比较小的情况)。

一、供给数据

供给数据(Feeding)就是之前讲解变量那一节提到的通过feed_dictplaceholder占位符提供值。

1
2
3
4
with tf.Session():
input = tf.placeholder(tf.float32)
classifier = ...
print classifier.eval(feed_dict={input: my_python_preprocessing_fn()})

二、文件读取

当数据集很大,使用此方法可以确保不是所有数据都立即占用内存(如60GB的YouTube-8m数据集)。从文件读取的过程可以通过以下步骤完成:

  1. 使用字符串张量 [“file0”,”file1”] 或者 [(“file%d”i)for in in range(2)] 的方式创建文件命名列表,或者使用 files=tf.train.match_filenames_once('*.JPG') 函数创建。

  2. 将文件名列表交给tf.train.string_input_producer 函数来生成一个先入先出的队列,文件阅读器会需要它来读取数据。

    1
    2
    #string_input_producer提供的可配置参数来设置文件名乱序和最大的训练迭代数
    filename_queue = tf.train.string_input_producer(files)
  3. Reader用于从文件名队列中读取文件:
    根据输入文件格式选择相应的阅读器,然后将文件名队列提供给阅读器的read方法。

  4. Decoder:使用一个或多个解码器和转换操作将值字符串解码为构成训练样本的张量:
    上一步阅读器的read方法会输出一个key来表征输入的文件和其中的纪录(对于调试非常有用),同时得到一个字符串标量,这个字符串标量可以被一个或多个解析器,或者转换操作将其解码为张量并且构造成为样本。

  5. 以CSV格式文件举例:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    filename_queue = tf.train.string_input_producer(["file0.csv", "file1.csv"])

    reader = tf.TextLineReader()
    key, value = reader.read(filename_queue)

    # Default values, in case of empty columns. Also specifies the type of the
    # decoded result.
    record_defaults = [[1], [1], [1], [1], [1]]
    col1, col2, col3, col4, col5 = tf.decode_csv(value, record_defaults=record_defaults)
    features = tf.concat(0, [col1, col2, col3, col4])

    with tf.Session() as sess:
    # Start populating the filename quee.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    for i in range(1200):
    # Retrieve a single instance:
    example, label = sess.run([features, col5])

    coord.request_stop()
    coord.join(threads)

每次read的执行都会从文件中读取一行内容, decode_csv 操作会解析这一行内容并将其转为张量列表。如果输入的参数有缺失,record_default参数可以根据张量的类型来设置默认值。
在调用run或者eval去执行read之前, 必须调用tf.train.start_queue_runners来将文件名填充到队列。否则read操作会被阻塞到文件名队列中有值为止。

三、预加载数据

这仅用于可以完全加载到存储器中的小的数据集。有两种方法:

  • 存储在常量中

    1
    2
    3
    4
    5
    6
    training_data = ...
    training_labels = ...
    with tf.Session as sess:
    x_data = tf.Constant(training_data)
    y_data = tf.Constant(training_labels)
    ...
  • 存储在变量中,初始化后,永远不要改变它的值

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    training_data = ...
    training_labels = ...
    with tf.Session() as sess:
    data_initializer = tf.placeholder(dtype=training_data.dtype,shape=training_data.shape)
    label_initializer = tf.placeholder(dtype=training_labels.dtype,shape=training_labels.shape)
    input_data = tf.Variable(data_initalizer, trainable=False, collections=[])
    input_labels = tf.Variable(label_initalizer, trainable=False, collections=[])
    ...
    sess.run(input_data.initializer,feed_dict={data_initializer: training_data})
    sess.run(input_labels.initializer,feed_dict={label_initializer: training_lables})

评论

Your browser is out-of-date!

Update your browser to view this website correctly. Update my browser now

×