Skip to content

Commit

Permalink
Merge branch 'feature/handle-generic-batch-sizes' into dev-0.2
Browse files Browse the repository at this point in the history
  • Loading branch information
GFallasRR committed May 29, 2019
2 parents 5b052dc + 0d9696b commit a9e43eb
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
6 changes: 6 additions & 0 deletions r2i/tensorflow/frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,12 @@ RuntimeError Frame::GetTensorShape (std::shared_ptr<TF_Graph> pgraph,
return error;
}

/* Some tensors may refer to generic batch sizes as -1. If this is the case
fallback to 1 */
if (-1 == (*dims)[0]) {
(*dims)[0] = 1;
}

type = TF_OperationOutputType(output);
size = TF_DataTypeSize(type);

Expand Down
6 changes: 6 additions & 0 deletions r2i/tensorflow/prediction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@ RuntimeError Prediction::SetTensor (std::shared_ptr<TF_Graph> pgraph,
return error;
}

/* Some tensors may refer to generic batch sizes as -1. If this is the case
fallback to 1 */
if (-1 == dims[0]) {
dims[0] = 1;
}

TF_DataType type = TF_OperationOutputType(output);
size_t type_size = TF_DataTypeSize(type);
size_t data_size = this->GetRequiredBufferSize (output, dims, num_dims);
Expand Down

0 comments on commit a9e43eb

Please sign in to comment.