I read TFRecord files, fed the data to my network, and training went fine. At the end of training I saved my model so I can run inference on it later. A simplified version of the code is below:
""" Training and saving """ training_dataset = tf.contrib.data.TFRecordDataset(training_record) training_dataset = training_dataset.map(ds._path_records_parser) training_dataset = training_dataset.batch(BATCH_SIZE) with tf.name_scope("iterators"): training_iterator = Iterator.from_structure(training_dataset.output_types,training_dataset.output_shapes) next_training_element = training_iterator.get_next() training_init_op = training_iterator.make_initializer(training_dataset) def train(num_epochs): # compute for the number of epochs for e in range(1,num_epochs+1): session.run(training_init_op) #initializing iterator here while True: try: images,labels = session.run(next_training_element) session.run(optimizer,Feed_dict={x: images,y_true: labels}) except tf.errors.OutOfRangeError: saver_name = './saved_models/ucf-model' print("Finished Training Epoch {}".format(e)) break """ Restoring """ # restoring the saved model and its variables session = tf.Session() saver = tf.train.import_Meta_graph(r'saved_models\ucf-model.Meta') saver.restore(session,tf.train.latest_checkpoint('.\saved_models')) graph = tf.get_default_graph() # restoring relevant tensors/ops accuracy = graph.get_tensor_by_name("accuracy/Mean:0") #the tensor that when evaluated returns the mean accuracy of the batch testing_iterator = graph.get_operation_by_name("iterators/Iterator") #my iterator used in testing. next_testing_element = graph.get_operation_by_name("iterators/IteratorGetNext") #the GetNext operator for my iterator # loading my testing set tfrecords testing_dataset = tf.contrib.data.TFRecordDataset(testing_record_path) testing_dataset = testing_dataset.map(ds._path_records_parser,num_threads=4,output_buffer_size=BATCH_SIZE*20) testing_dataset = testing_dataset.batch(BATCH_SIZE) testing_init_op = testing_iterator.make_initializer(testing_dataset) #to initialize the dataset with tf.Session() as session: session.run(testing_init_op) while True: try: images,labels = session.run(next_testing_element) accuracy = session.run(accuracy,Feed_dict={x: test_images,y_true: test_labels}) #error here,x,y_true not defined except tf.errors.OutOfRangeError: break
My problem is mainly with restoring the model: how do I feed the test data to the network?
> When I restore my iterator with testing_iterator = graph.get_operation_by_name("iterators/Iterator") and next_testing_element = graph.get_operation_by_name("iterators/IteratorGetNext"), I get the following error:
GetNext() failed because the iterator has not been initialized. Ensure that you have run the initializer operation for this iterator before getting the next element.
> So I tried to initialize my dataset with testing_init_op = testing_iterator.make_initializer(testing_dataset), but I got this error: AttributeError: 'Operation' object has no attribute 'make_initializer'.
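Side note: that AttributeError is expected, because get_operation_by_name returns a plain tf.Operation rather than a tf.data.Iterator object, so Python-side methods like make_initializer are not available on it. A minimal sketch of what can be pulled out of a restored graph instead; the exact tensor names below are assumptions based on the name scope used above:

graph = tf.get_default_graph()

# The GetNext outputs can be fetched as tensors (one per dataset component),
# but the restored Operation itself cannot be re-initialized with a new dataset.
next_images = graph.get_tensor_by_name("iterators/IteratorGetNext:0")
next_labels = graph.get_tensor_by_name("iterators/IteratorGetNext:1")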
Another issue: since an iterator is used, there is no need for placeholders in the training model, because the iterator feeds data to the graph directly. But in that case, how do I restore the feed_dict keys in the third-to-last line, where I feed data to the "accuracy" op?
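For what it's worth, if the network input is wired directly to the iterator's get_next() output, the evaluation loop needs no feed_dict at all. A minimal sketch of that case, assuming accuracy has been fetched from the restored graph and that an initializer op for the testing data (here called testing_init_op) is somehow available:

session.run(testing_init_op)  # assumed: an initializer op for the testing dataset
while True:
    try:
        # no feed_dict needed: the network pulls its inputs from the iterator
        batch_accuracy = session.run(accuracy)
        print(batch_accuracy)
    except tf.errors.OutOfRangeError:
        break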
EDIT: If someone could suggest a way to add placeholders between the iterator and the network input, I could try to run the graph by evaluating the "accuracy" tensor while feeding data to those placeholders and ignoring the iterator altogether.
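On the EDIT: one way to put a placeholder between the iterator and the network is tf.placeholder_with_default, which reads from the iterator unless something is fed. A minimal sketch, where HEIGHT, WIDTH, CHANNELS and NUM_CLASSES stand in for the real input dimensions (they are not defined in the question's code):

next_images, next_labels = training_iterator.get_next()

# During training these read from the iterator; at inference time they can be
# overridden via feed_dict, bypassing the iterator entirely.
x = tf.placeholder_with_default(next_images, shape=[None, HEIGHT, WIDTH, CHANNELS], name='x')
y_true = tf.placeholder_with_default(next_labels, shape=[None, NUM_CLASSES], name='y_true')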
Solution
The trick is to give the dataset initialization operation a name when you build the graph, so that you can retrieve it by that name from the restored graph later. That is, when creating the graph you can do:
dataset_init_op = iterator.make_initializer(dataset, name='dataset_init')
and then restore it after importing the meta graph with:
dataset_init_op = graph.get_operation_by_name('dataset_init')
Here is a self-contained code snippet that compares the outputs of a randomly initialized model before saving and after restoring.
Saving the iterator
import numpy as np
import tensorflow as tf

np.random.seed(42)
data = np.random.random([4, 4])

X = tf.placeholder(dtype=tf.float32, shape=[4, 4], name='X')
dataset = tf.data.Dataset.from_tensor_slices(X)
iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
dataset_next_op = iterator.get_next()

# name the initializer operation so it can be retrieved by name after restoring
dataset_init_op = iterator.make_initializer(dataset, name='dataset_init')

w = np.random.random([1, 4])
W = tf.Variable(w, name='W', dtype=tf.float32)
output = tf.multiply(W, dataset_next_op, name='output')

sess = tf.Session()
saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())
sess.run(dataset_init_op, feed_dict={X: data})

while True:
    try:
        print(sess.run(output))
    except tf.errors.OutOfRangeError:
        saver.save(sess, 'tmp/', global_step=1002)
        break
You can then restore the same model for inference as follows:
Restoring the saved iterator
import os

import numpy as np
import tensorflow as tf

np.random.seed(42)
data = np.random.random([4, 4])

tf.reset_default_graph()
sess = tf.Session()
saver = tf.train.import_meta_graph('tmp/-1002.meta')
ckpt = tf.train.get_checkpoint_state(os.path.dirname('tmp/checkpoint'))
saver.restore(sess, ckpt.model_checkpoint_path)
graph = tf.get_default_graph()

# Restore the init operation and the tensors we need by name
dataset_init_op = graph.get_operation_by_name('dataset_init')
X = graph.get_tensor_by_name('X:0')
output = graph.get_tensor_by_name('output:0')

sess.run(dataset_init_op, feed_dict={X: data})
while True:
    try:
        print(sess.run(output))
    except tf.errors.OutOfRangeError:
        break
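The same naming trick should carry over to the TFRecord pipeline from the question. A rough sketch, where the filename placeholder (record_files) is an addition of mine and BATCH_SIZE and ds._path_records_parser are reused from the question's code:

# Build the dataset from a filename placeholder so the named initializer can be
# re-run on a different record file after restoring the meta graph.
record_files = tf.placeholder(tf.string, shape=[None], name='record_files')
dataset = tf.data.TFRecordDataset(record_files)
dataset = dataset.map(ds._path_records_parser)
dataset = dataset.batch(BATCH_SIZE)

iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
next_element = iterator.get_next()
dataset_init_op = iterator.make_initializer(dataset, name='dataset_init')

# After import_meta_graph and restore, switch to the testing records:
# init_op = graph.get_operation_by_name('dataset_init')
# files = graph.get_tensor_by_name('record_files:0')
# session.run(init_op, feed_dict={files: [testing_record_path]})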