I am using Keras 2.0.8 with the TensorFlow 1.3.0 backend.
I load a model in the class's __init__ and then use it to predict from multiple threads.
import tensorflow as tf
from keras import backend as K
from keras.models import load_model


class CNN:
    def __init__(self, model_path):
        self.cnn_model = load_model(model_path)
        self.session = K.get_session()
        self.graph = tf.get_default_graph()

    def query_cnn(self, data):
        X = self.preproccesing(data)
        with self.session.as_default():
            with self.graph.as_default():
                return self.cnn_model.predict(X)
I initialize CNN once, and the query_cnn method is called from multiple threads.
The exception I get in my log is:
  File "/home/*/Similarity/CNN.py", line 43, in query_cnn
    return self.cnn_model.predict(X)
  File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 913, in predict
    return self.model.predict(x, batch_size=batch_size, verbose=verbose)
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/training.py", line 1713, in predict
    verbose=verbose, steps=steps)
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/training.py", line 1269, in _predict_loop
    batch_outs = f(ins_batch)
  File "/usr/local/lib/python3.5/dist-packages/keras/backend/tensorflow_backend.py", line 2273, in __call__
    **self.session_kwargs)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 895, in run
    run_metadata_ptr)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1124, in _run
    feed_dict_tensor, options, run_metadata)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1321, in _do_run
    options, run_metadata)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1340, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.NotFoundError: PruneForTargets: Some target nodes not found: group_deps
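The code works fine most of the time, so it is probably a multithreading problem.
How can I fix it?
Solution
Make sure that graph creation is finished before you create the other threads.
Calling finalize() on the graph can help you with that.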
def __init__(self, model_path):
    self.cnn_model = load_model(model_path)
    self.session = K.get_session()
    self.graph = tf.get_default_graph()
    self.graph.finalize()
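Update 1: finalize() makes your graph read-only, so it can be used safely from multiple threads. As a side effect, it helps you find unintended behavior and sometimes memory leaks, because it raises an exception when you try to modify the graph.
Imagine you have a thread that does, for example, one-hot encoding of your inputs. (Bad example:)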
def preprocessing(self, data):
    # creates a new one_hot op in the graph on every call
    one_hot_data = tf.one_hot(data, depth=self.num_classes)
    return self.session.run(one_hot_data)
If you print the number of nodes in the graph, you will see it grow over time:
# number of nodes in the tf graph
print(len(list(tf.get_default_graph().as_graph_def().node)))
If you define the graph first, however, this is not the case (slightly better code):
def preprocessing(self, data):
    # run the pre-created operation, with self.input as the placeholder
    return self.session.run(self.one_hot_data, feed_dict={self.input: data})
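To make the difference between the two versions concrete without needing TensorFlow, here is a minimal sketch. ToyGraph, add_op, preprocessing_bad and preprocessing_better are hypothetical stand-ins invented for this illustration. They only mimic the behavior described above (a graph that grows when ops are created per call, and that rejects changes once finalized) and are not the TensorFlow API.

```python
class ToyGraph:
    """Hypothetical stand-in for a computation graph (not TensorFlow)."""

    def __init__(self):
        self.nodes = []
        self._finalized = False

    def add_op(self, name):
        # A finalized graph rejects modifications, as described above.
        if self._finalized:
            raise RuntimeError("Graph is finalized and cannot be modified.")
        self.nodes.append(name)
        return len(self.nodes) - 1  # index used as a handle to the op

    def finalize(self):
        self._finalized = True


graph = ToyGraph()
one_hot_op = graph.add_op("one_hot")  # defined once, before finalizing


def preprocessing_bad(data):
    # Creates a new op on every call, so the graph keeps growing.
    return graph.add_op("one_hot")


def preprocessing_better(data):
    # Reuses the op that was created up front; the graph is untouched.
    return one_hot_op


for _ in range(3):
    preprocessing_bad([0, 1])
nodes_before_finalize = len(graph.nodes)  # 1 up-front op + 3 leaked ops

graph.finalize()
preprocessing_better([0, 1])  # still works after finalize

try:
    preprocessing_bad([0, 1])
    caught = False
except RuntimeError:
    caught = True  # the leak is now reported instead of going unnoticed
```

In this toy version the bad variant leaks one node per call until finalize() turns the leak into an exception, which is the debugging benefit mentioned in Update 1.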
Update 2: According to this thread, you need to call model._make_predict_function() on the Keras model before doing any multithreading.
Keras builds the GPU function the first time you call predict(). That
way, if you never call predict, you save some time and resources.
However, the first time you call predict is slightly slower than every
other time.
Updated code:
def __init__(self, model_path):
    self.cnn_model = load_model(model_path)
    self.cnn_model._make_predict_function()  # have to initialize before threading
    self.session = K.get_session()
    self.graph = tf.get_default_graph()
    self.graph.finalize()  # make graph read-only
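The reason initialization has to happen before threading can be sketched with a toy model that builds its predict function lazily, as the quote above describes for Keras. LazyModel, build_count and worker are hypothetical names made up for this sketch, and nothing here calls Keras. The point is only that a single call from the main thread forces the lazy build to happen exactly once, before any worker thread can reach it.

```python
import threading


class LazyModel:
    """Hypothetical stand-in for a model that builds predict lazily (not Keras)."""

    def __init__(self):
        self._predict_fn = None
        self.build_count = 0

    def _build(self):
        # Stands in for the expensive, non-thread-safe graph construction.
        self.build_count += 1
        self._predict_fn = lambda xs: [2 * x for x in xs]

    def predict(self, xs):
        if self._predict_fn is None:  # the first call pays the build cost
            self._build()
        return self._predict_fn(xs)


model = LazyModel()
model.predict([0])  # warm-up in the main thread, before any worker starts

results = {}


def worker(i):
    # By now the predict function already exists, so workers only read it.
    results[i] = model.predict([i])


threads = [threading.Thread(target=worker, args=(i,)) for i in range(5)]
for th in threads:
    th.start()
for th in threads:
    th.join()
```

Without the warm-up line, several workers could see an unbuilt model at the same time and each try to build it, which is the kind of race the warm-up is meant to avoid.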
Update 3: I did a proof of concept of a warm-up, because _make_predict_function() does not seem to work as expected.
First I created a dummy model:
import tensorflow as tf
from keras.layers import *
from keras.models import *

model = Sequential()
model.add(Dense(256, input_shape=(2,)))
model.add(Dense(1, activation='softmax'))

model.compile(loss='mean_squared_error', optimizer='adam')
model.save("dummymodel")
Then, in another script, I loaded that model and ran it on multiple threads:
import tensorflow as tf
from keras import backend as K
from keras.models import load_model
import threading as t
import numpy as np

K.clear_session()


class CNN:
    def __init__(self, model_path):
        self.cnn_model = load_model(model_path)
        self.cnn_model.predict(np.array([[0, 0]]))  # warmup
        self.session = K.get_session()
        self.graph = tf.get_default_graph()
        self.graph.finalize()  # finalize

    def preproccesing(self, data):
        # dummy
        return data

    def query_cnn(self, data):
        X = self.preproccesing(data)
        with self.session.as_default():
            with self.graph.as_default():
                prediction = self.cnn_model.predict(X)
        print(prediction)
        return prediction


cnn = CNN("dummymodel")

th = t.Thread(target=cnn.query_cnn, kwargs={"data": np.random.random((500, 2))})
th2 = t.Thread(target=cnn.query_cnn, kwargs={"data": np.random.random((500, 2))})
th3 = t.Thread(target=cnn.query_cnn, kwargs={"data": np.random.random((500, 2))})
th4 = t.Thread(target=cnn.query_cnn, kwargs={"data": np.random.random((500, 2))})
th5 = t.Thread(target=cnn.query_cnn, kwargs={"data": np.random.random((500, 2))})

th.start()
th2.start()
th3.start()
th4.start()
th5.start()

th2.join()
th.join()
th3.join()
th5.join()
th4.join()
By commenting out the warm-up and finalize lines, I was able to reproduce your first issue.