Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions tensorflow_serving/util/json_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,9 @@ class JsonWriterWithLimit {

constexpr int kMaxJsonDebugStringBytes = 256;

// Maximum nesting depth for tensors in JSON.
constexpr int kMaxTensorJsonDepth = 1024;

// Stringify JSON value (only for use in error reporting or debugging).
// Large JSON objects are truncated to kMaxJsonDebugStringBytes.
string JsonValueToDebugString(const rapidjson::Value& val) {
Expand Down Expand Up @@ -346,13 +349,18 @@ Status AddValueToTensor(const rapidjson::Value& val, DataType dtype,
// `val` can be scalar or list or list of lists with arbitrary nesting. If a
// scalar (non array) is passed, we do not add dimension info to shape (as
// scalars do not have a dimension).
void GetDenseTensorShape(const rapidjson::Value& val, TensorShapeProto* shape) {
if (!val.IsArray()) return;
Status GetDenseTensorShape(const rapidjson::Value& val, TensorShapeProto* shape,
int depth = 0) {
if (depth >= kMaxTensorJsonDepth) {
return errors::InvalidArgument("Exceeded maximum tensor JSON nesting depth");
}
if (!val.IsArray()) return OkStatus();
const auto size = val.Size();
shape->add_dim()->set_size(size);
if (size > 0) {
GetDenseTensorShape(val[0], shape);
return GetDenseTensorShape(val[0], shape, depth + 1);
}
return OkStatus();
}

bool IsValBase64Object(const rapidjson::Value& val) {
Expand Down Expand Up @@ -391,6 +399,9 @@ Status JsonDecodeBase64Object(const rapidjson::Value& val,
// Fills tensor values.
Status FillTensorProto(const rapidjson::Value& val, int level, DataType dtype,
int* val_count, TensorProto* tensor) {
if (level >= kMaxTensorJsonDepth) {
return errors::InvalidArgument("Exceeded maximum tensor JSON nesting depth");
}
const auto rank = tensor->tensor_shape().dim_size();
if (!val.IsArray()) {
// DOM tree for a (dense) tensor will always have all values
Expand Down Expand Up @@ -453,7 +464,8 @@ Status AddInstanceItem(const rapidjson::Value& item, const string& name,
const auto dtype = tensorinfo_map.at(name).dtype();
auto* tensor = &(*tensor_map)[name];
tensor->mutable_tensor_shape()->Clear();
GetDenseTensorShape(item, tensor->mutable_tensor_shape());
TF_RETURN_IF_ERROR(
GetDenseTensorShape(item, tensor->mutable_tensor_shape()));
TF_RETURN_IF_ERROR(
FillTensorProto(item, 0 /* level */, dtype, &size, tensor));
if (!size_map->count(name)) {
Expand Down Expand Up @@ -623,7 +635,8 @@ Status FillTensorMapFromInputsMap(

auto* tensor = &(*tensor_map)[tensorinfo_map.begin()->first];
tensor->set_dtype(tensorinfo_map.begin()->second.dtype());
GetDenseTensorShape(val, tensor->mutable_tensor_shape());
TF_RETURN_IF_ERROR(
GetDenseTensorShape(val, tensor->mutable_tensor_shape()));
int unused_size = 0;
TF_RETURN_IF_ERROR(FillTensorProto(val, 0 /* level */, tensor->dtype(),
&unused_size, tensor));
Expand All @@ -639,7 +652,8 @@ Status FillTensorMapFromInputsMap(
auto* tensor = &(*tensor_map)[name];
tensor->set_dtype(dtype);
tensor->mutable_tensor_shape()->Clear();
GetDenseTensorShape(item->value, tensor->mutable_tensor_shape());
TF_RETURN_IF_ERROR(
GetDenseTensorShape(item->value, tensor->mutable_tensor_shape()));
int unused_size = 0;
TF_RETURN_IF_ERROR(FillTensorProto(item->value, 0 /* level */, dtype,
&unused_size, tensor));
Expand Down
21 changes: 21 additions & 0 deletions tensorflow_serving/util/json_tensor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,27 @@ TEST(JsontensorTest, DeeplyNestedWellFormed) {
EXPECT_EQ(tmap.size(), 1);
}

TEST(JsontensorTest, DeeplyNestedTensorValue) {
TensorInfoMap infomap;
ASSERT_TRUE(
TextFormat::ParseFromString("dtype: DT_INT32", &infomap["default"]));

PredictRequest req;
JsonPredictRequestFormat format;
std::string json_req = R"({"instances":)";
int depth = 2000;
json_req.append(depth, '[');
json_req.append("1");
json_req.append(depth, ']');
json_req.append("}");
auto status =
FillPredictRequestFromJson(json_req, getmap(infomap), &req, &format);
ASSERT_FALSE(status.ok());
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(status.message(),
HasSubstr("Exceeded maximum tensor JSON nesting depth"));
}

TEST(JsontensorTest, DeeplyNestedMalformed) {
TensorInfoMap infomap;
ASSERT_TRUE(
Expand Down