summaryrefslogtreecommitdiff
path: root/libavfilter/dnn_backend_tf.c
diff options
context:
space:
mode:
Diffstat (limited to 'libavfilter/dnn_backend_tf.c')
-rw-r--r--libavfilter/dnn_backend_tf.c28
1 files changed, 24 insertions, 4 deletions
diff --git a/libavfilter/dnn_backend_tf.c b/libavfilter/dnn_backend_tf.c
index ca6472d445..ba959ae3a2 100644
--- a/libavfilter/dnn_backend_tf.c
+++ b/libavfilter/dnn_backend_tf.c
@@ -79,10 +79,31 @@ static TF_Buffer *read_graph(const char *model_filename)
return graph_buf;
}
-static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output)
+static TF_Tensor *allocate_input_tensor(const DNNInputData *input)
{
- TFModel *tf_model = (TFModel *)model;
+ TF_DataType dt;
+ size_t size;
int64_t input_dims[] = {1, input->height, input->width, input->channels};
+ switch (input->dt) {
+ case DNN_FLOAT:
+ dt = TF_FLOAT;
+ size = sizeof(float);
+ break;
+ case DNN_UINT8:
+ dt = TF_UINT8;
+ size = sizeof(char);
+ break;
+ default:
+ av_assert0(!"should not reach here");
+ }
+
+ return TF_AllocateTensor(dt, input_dims, 4,
+ input_dims[1] * input_dims[2] * input_dims[3] * size);
+}
+
+static DNNReturnType set_input_output_tf(void *model, DNNInputData *input, const char *input_name, const char **output_names, uint32_t nb_output)
+{
+ TFModel *tf_model = (TFModel *)model;
TF_SessionOptions *sess_opts;
const TF_Operation *init_op = TF_GraphOperationByName(tf_model->graph, "init");
@@ -95,8 +116,7 @@ static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char
if (tf_model->input_tensor){
TF_DeleteTensor(tf_model->input_tensor);
}
- tf_model->input_tensor = TF_AllocateTensor(TF_FLOAT, input_dims, 4,
- input_dims[1] * input_dims[2] * input_dims[3] * sizeof(float));
+ tf_model->input_tensor = allocate_input_tensor(input);
if (!tf_model->input_tensor){
return DNN_ERROR;
}