diff --git a/tensorflow_serving/util/json_tensor.cc b/tensorflow_serving/util/json_tensor.cc index 82532b12e38..3217a82e769 100644 --- a/tensorflow_serving/util/json_tensor.cc +++ b/tensorflow_serving/util/json_tensor.cc @@ -224,6 +224,9 @@ class JsonWriterWithLimit { constexpr int kMaxJsonDebugStringBytes = 256; +// Maximum nesting depth for tensors in JSON. +constexpr int kMaxTensorJsonDepth = 1024; + // Stringify JSON value (only for use in error reporting or debugging). // Large JSON objects are truncated to kMaxJsonDebugStringBytes. string JsonValueToDebugString(const rapidjson::Value& val) { @@ -346,13 +349,18 @@ Status AddValueToTensor(const rapidjson::Value& val, DataType dtype, // `val` can be scalar or list or list of lists with arbitrary nesting. If a // scalar (non array) is passed, we do not add dimension info to shape (as // scalars do not have a dimension). -void GetDenseTensorShape(const rapidjson::Value& val, TensorShapeProto* shape) { - if (!val.IsArray()) return; +Status GetDenseTensorShape(const rapidjson::Value& val, TensorShapeProto* shape, + int depth = 0) { + if (depth >= kMaxTensorJsonDepth) { + return errors::InvalidArgument("Exceeded maximum tensor JSON nesting depth"); + } + if (!val.IsArray()) return OkStatus(); const auto size = val.Size(); shape->add_dim()->set_size(size); if (size > 0) { - GetDenseTensorShape(val[0], shape); + return GetDenseTensorShape(val[0], shape, depth + 1); } + return OkStatus(); } bool IsValBase64Object(const rapidjson::Value& val) { @@ -391,6 +399,9 @@ Status JsonDecodeBase64Object(const rapidjson::Value& val, // Fills tensor values. Status FillTensorProto(const rapidjson::Value& val, int level, DataType dtype, int* val_count, TensorProto* tensor) { + if (level >= kMaxTensorJsonDepth) { + return errors::InvalidArgument("Exceeded maximum tensor JSON nesting depth"); + } const auto rank = tensor->tensor_shape().dim_size(); if (!val.IsArray()) { // DOM tree for a (dense) tensor will always have all values @@ -453,7 +464,8 @@ Status AddInstanceItem(const rapidjson::Value& item, const string& name, const auto dtype = tensorinfo_map.at(name).dtype(); auto* tensor = &(*tensor_map)[name]; tensor->mutable_tensor_shape()->Clear(); - GetDenseTensorShape(item, tensor->mutable_tensor_shape()); + TF_RETURN_IF_ERROR( + GetDenseTensorShape(item, tensor->mutable_tensor_shape())); TF_RETURN_IF_ERROR( FillTensorProto(item, 0 /* level */, dtype, &size, tensor)); if (!size_map->count(name)) { @@ -623,7 +635,8 @@ Status FillTensorMapFromInputsMap( auto* tensor = &(*tensor_map)[tensorinfo_map.begin()->first]; tensor->set_dtype(tensorinfo_map.begin()->second.dtype()); - GetDenseTensorShape(val, tensor->mutable_tensor_shape()); + TF_RETURN_IF_ERROR( + GetDenseTensorShape(val, tensor->mutable_tensor_shape())); int unused_size = 0; TF_RETURN_IF_ERROR(FillTensorProto(val, 0 /* level */, tensor->dtype(), &unused_size, tensor)); @@ -639,7 +652,8 @@ Status FillTensorMapFromInputsMap( auto* tensor = &(*tensor_map)[name]; tensor->set_dtype(dtype); tensor->mutable_tensor_shape()->Clear(); - GetDenseTensorShape(item->value, tensor->mutable_tensor_shape()); + TF_RETURN_IF_ERROR( + GetDenseTensorShape(item->value, tensor->mutable_tensor_shape())); int unused_size = 0; TF_RETURN_IF_ERROR(FillTensorProto(item->value, 0 /* level */, dtype, &unused_size, tensor)); diff --git a/tensorflow_serving/util/json_tensor_test.cc b/tensorflow_serving/util/json_tensor_test.cc index c117da26c54..d00619a5f1c 100644 --- a/tensorflow_serving/util/json_tensor_test.cc +++ b/tensorflow_serving/util/json_tensor_test.cc @@ -101,6 +101,27 @@ TEST(JsontensorTest, DeeplyNestedWellFormed) { EXPECT_EQ(tmap.size(), 1); } +TEST(JsontensorTest, DeeplyNestedTensorValue) { + TensorInfoMap infomap; + ASSERT_TRUE( + TextFormat::ParseFromString("dtype: DT_INT32", &infomap["default"])); + + PredictRequest req; + JsonPredictRequestFormat format; + std::string json_req = R"({"instances":)"; + int depth = 2000; + json_req.append(depth, '['); + json_req.append("1"); + json_req.append(depth, ']'); + json_req.append("}"); + auto status = + FillPredictRequestFromJson(json_req, getmap(infomap), &req, &format); + ASSERT_FALSE(status.ok()); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(status.message(), + HasSubstr("Exceeded maximum tensor JSON nesting depth")); +} + TEST(JsontensorTest, DeeplyNestedMalformed) { TensorInfoMap infomap; ASSERT_TRUE(