diff --git a/sjsonnet/src-js/sjsonnet/CharSWAR.scala b/sjsonnet/src-js/sjsonnet/CharSWAR.scala index ab2b2c5f..659e1233 100644 --- a/sjsonnet/src-js/sjsonnet/CharSWAR.scala +++ b/sjsonnet/src-js/sjsonnet/CharSWAR.scala @@ -23,10 +23,11 @@ object CharSWAR { false } - def isAsciiJsonSafe(s: String): Boolean = { - var i = 0 - val len = s.length - while (i < len) { + def isAsciiJsonSafe(s: String): Boolean = isAsciiJsonSafe(s, 0, s.length) + + def isAsciiJsonSafe(s: String, from: Int, to: Int): Boolean = { + var i = from + while (i < to) { val c = s.charAt(i) if (c < 32 || c == '"' || c == '\\' || c >= 128) return false i += 1 diff --git a/sjsonnet/src-js/sjsonnet/Platform.scala b/sjsonnet/src-js/sjsonnet/Platform.scala index 3d95aa7a..5d15694b 100644 --- a/sjsonnet/src-js/sjsonnet/Platform.scala +++ b/sjsonnet/src-js/sjsonnet/Platform.scala @@ -38,6 +38,9 @@ object Platform { def isAsciiJsonSafe(s: String): Boolean = CharSWAR.isAsciiJsonSafe(s) + def isAsciiJsonSafe(s: String, from: Int, to: Int): Boolean = + CharSWAR.isAsciiJsonSafe(s, from, to) + private def nodeToJson(node: Node): ujson.Value = node match { case _: Node.ScalarNode => YamlDecoder.forAny.construct(node).getOrElse("") match { diff --git a/sjsonnet/src-jvm/sjsonnet/CharSWAR.java b/sjsonnet/src-jvm/sjsonnet/CharSWAR.java index fe67611b..5a22470c 100644 --- a/sjsonnet/src-jvm/sjsonnet/CharSWAR.java +++ b/sjsonnet/src-jvm/sjsonnet/CharSWAR.java @@ -91,12 +91,20 @@ static boolean hasEscapeChar(byte[] arr, int from, int to) { * encoding step: all chars must be printable ASCII excluding {@code '"'} and {@code '\\'}. */ static boolean isAsciiJsonSafe(String str) { - int len = str.length(); + return isAsciiJsonSafe(str, 0, str.length()); + } + + /** + * Range variant of {@link #isAsciiJsonSafe(String)}. Used by the format parser to scan literal + * windows of a format string without allocating substrings. + */ + static boolean isAsciiJsonSafe(String str, int from, int to) { + int len = to - from; if (len < 8) { - return isAsciiJsonSafeScalar(str, len); + return isAsciiJsonSafeScalar(str, from, to); } - int i = 0; - int limit = len - 3; // 4 UTF-16 chars per word + int i = from; + int limit = to - 3; // 4 UTF-16 chars per word while (i < limit) { long word = ((long) str.charAt(i)) | @@ -106,7 +114,7 @@ static boolean isAsciiJsonSafe(String str) { if (swarHasUnsafeAsciiChar(word)) return false; i += 4; } - while (i < len) { + while (i < to) { char c = str.charAt(i); if (c < 32 || c == '"' || c == '\\' || c >= 128) return false; i++; @@ -216,8 +224,8 @@ private static boolean hasEscapeCharScalar(String s, int len) { return false; } - private static boolean isAsciiJsonSafeScalar(String s, int len) { - for (int i = 0; i < len; i++) { + private static boolean isAsciiJsonSafeScalar(String s, int from, int to) { + for (int i = from; i < to; i++) { char c = s.charAt(i); if (c < 32 || c == '"' || c == '\\' || c >= 128) return false; } diff --git a/sjsonnet/src-jvm/sjsonnet/Platform.scala b/sjsonnet/src-jvm/sjsonnet/Platform.scala index 23f329ac..4ba3e07b 100644 --- a/sjsonnet/src-jvm/sjsonnet/Platform.scala +++ b/sjsonnet/src-jvm/sjsonnet/Platform.scala @@ -33,6 +33,9 @@ object Platform { def isAsciiJsonSafe(s: String): Boolean = CharSWAR.isAsciiJsonSafe(s) + def isAsciiJsonSafe(s: String, from: Int, to: Int): Boolean = + CharSWAR.isAsciiJsonSafe(s, from, to) + def gzipBytes(b: Array[Byte]): String = { val outputStream: ByteArrayOutputStream = new ByteArrayOutputStream(b.length) val gzip: GZIPOutputStream = new GZIPOutputStream(outputStream) diff --git a/sjsonnet/src-native/sjsonnet/CharSWAR.scala b/sjsonnet/src-native/sjsonnet/CharSWAR.scala index a9192fdd..63cde049 100644 --- a/sjsonnet/src-native/sjsonnet/CharSWAR.scala +++ b/sjsonnet/src-native/sjsonnet/CharSWAR.scala @@ -73,12 +73,14 @@ object CharSWAR { false } - def isAsciiJsonSafe(s: String): Boolean = { - val len = s.length - if (len < 8) return isAsciiJsonSafeScalar(s, len) + def isAsciiJsonSafe(s: String): Boolean = isAsciiJsonSafe(s, 0, s.length) - var i = 0 - val limit = len - 3 + def isAsciiJsonSafe(s: String, from: Int, to: Int): Boolean = { + val len = to - from + if (len < 8) return isAsciiJsonSafeScalar(s, from, to) + + var i = from + val limit = to - 3 while (i < limit) { val word = (s.charAt(i).toLong) | @@ -88,7 +90,7 @@ object CharSWAR { if (swarHasUnsafeAsciiChar(word)) return false i += 4 } - while (i < len) { + while (i < to) { val c = s.charAt(i) if (c < 32 || c == '"' || c == '\\' || c >= 128) return false i += 1 @@ -166,9 +168,9 @@ object CharSWAR { @inline private def zero16(word: Long): Long = ~((word & U16_HOLE) + U16_HOLE | word | U16_HOLE) - @inline private def isAsciiJsonSafeScalar(s: String, len: Int): Boolean = { - var i = 0 - while (i < len) { + @inline private def isAsciiJsonSafeScalar(s: String, from: Int, to: Int): Boolean = { + var i = from + while (i < to) { val c = s.charAt(i) if (c < 32 || c == '"' || c == '\\' || c >= 128) return false i += 1 diff --git a/sjsonnet/src-native/sjsonnet/Platform.scala b/sjsonnet/src-native/sjsonnet/Platform.scala index c9570722..ab8f580e 100644 --- a/sjsonnet/src-native/sjsonnet/Platform.scala +++ b/sjsonnet/src-native/sjsonnet/Platform.scala @@ -38,6 +38,9 @@ object Platform { def isAsciiJsonSafe(s: String): Boolean = CharSWAR.isAsciiJsonSafe(s) + def isAsciiJsonSafe(s: String, from: Int, to: Int): Boolean = + CharSWAR.isAsciiJsonSafe(s, from, to) + def gzipBytes(b: Array[Byte]): String = { val outputStream: ByteArrayOutputStream = new ByteArrayOutputStream(b.length) val gzip: GZIPOutputStream = new GZIPOutputStream(outputStream) diff --git a/sjsonnet/src/sjsonnet/Format.scala b/sjsonnet/src/sjsonnet/Format.scala index 897fc1ca..e69c60c1 100644 --- a/sjsonnet/src/sjsonnet/Format.scala +++ b/sjsonnet/src/sjsonnet/Format.scala @@ -294,20 +294,13 @@ object Format { } /** - * Scalar ASCII-JSON-safe check over a substring window of `s`. Matches the predicate used by + * ASCII-JSON-safe check over a substring window of `s`. Matches the predicate used by * [[Platform.isAsciiJsonSafe]] (printable ASCII, no `"` or `\`). Used at format-parse time so the * result can be cached on [[RuntimeFormat]] and combined with per-value ASCII-safety at format * time. */ - private def isAsciiJsonSafeRange(s: String, from: Int, to: Int): Boolean = { - var i = from - while (i < to) { - val c = s.charAt(i) - if (c < 32 || c == '"' || c == '\\' || c >= 128) return false - i += 1 - } - true - } + private def isAsciiJsonSafeRange(s: String, from: Int, to: Int): Boolean = + Platform.isAsciiJsonSafe(s, from, to) /** * Hand-written format string scanner. Replaces the fastparse-based parser with direct diff --git a/sjsonnet/src/sjsonnet/stdlib/StringModule.scala b/sjsonnet/src/sjsonnet/stdlib/StringModule.scala index 8cdd64fd..a8ab4304 100644 --- a/sjsonnet/src/sjsonnet/stdlib/StringModule.scala +++ b/sjsonnet/src/sjsonnet/stdlib/StringModule.scala @@ -217,7 +217,10 @@ object StringModule extends AbstractFunctionModule { if (!Character.isValidCodePoint(c)) { Error.fail(s"Invalid unicode code point, got " + c) } - Val.Str(pos, Character.toString(c)) + val s = Character.toString(c) + // Single-codepoint result; ASCII printable except '"' and '\\' is JSON-safe. + if (c >= 0x20 && c < 0x7f && c != '"' && c != '\\') Val.Str.asciiSafe(pos, s) + else Val.Str(pos, s) } } @@ -235,7 +238,13 @@ object StringModule extends AbstractFunctionModule { if (fromForce.isEmpty) { Error.fail("Cannot replace empty string in strReplace") } - Val.Str(pos, str.value.asString.replace(fromForce, to.value.asString)) + val srcVal = str.value + val toVal = to.value + val out = srcVal.asString.replace(fromForce, toVal.asString) + // Result is asciiSafe iff both src and `to` are asciiSafe (`from` is removed). + val srcSafe = srcVal.isInstanceOf[Val.Str] && srcVal.asInstanceOf[Val.Str]._asciiSafe + val toSafe = toVal.isInstanceOf[Val.Str] && toVal.asInstanceOf[Val.Str]._asciiSafe + if (srcSafe && toSafe) Val.Str.asciiSafe(pos, out) else Val.Str(pos, out) } } @@ -369,15 +378,17 @@ object StringModule extends AbstractFunctionModule { */ private object StripChars extends Val.Builtin2("stripChars", "str", "chars") { def evalRhs(str: Eval, chars: Eval, ev: EvalScope, pos: Position): Val = { - Val.Str( - pos, - StripUtils.strip( - str.value.asString, - chars.value.asString, - left = true, - right = true - ) + val v = str.value + val out = StripUtils.strip( + v.asString, + chars.value.asString, + left = true, + right = true ) + v match { + case vs: Val.Str if vs._asciiSafe => Val.Str.asciiSafe(pos, out) + case _ => Val.Str(pos, out) + } } } @@ -390,15 +401,17 @@ object StringModule extends AbstractFunctionModule { */ private object LStripChars extends Val.Builtin2("lstripChars", "str", "chars") { def evalRhs(str: Eval, chars: Eval, ev: EvalScope, pos: Position): Val = { - Val.Str( - pos, - StripUtils.strip( - str.value.asString, - chars.value.asString, - left = true, - right = false - ) + val v = str.value + val out = StripUtils.strip( + v.asString, + chars.value.asString, + left = true, + right = false ) + v match { + case vs: Val.Str if vs._asciiSafe => Val.Str.asciiSafe(pos, out) + case _ => Val.Str(pos, out) + } } } @@ -411,15 +424,17 @@ object StringModule extends AbstractFunctionModule { */ private object RStripChars extends Val.Builtin2("rstripChars", "str", "chars") { def evalRhs(str: Eval, chars: Eval, ev: EvalScope, pos: Position): Val = { - Val.Str( - pos, - StripUtils.strip( - str.value.asString, - chars.value.asString, - left = false, - right = true - ) + val v = str.value + val out = StripUtils.strip( + v.asString, + chars.value.asString, + left = false, + right = true ) + v match { + case vs: Val.Str if vs._asciiSafe => Val.Str.asciiSafe(pos, out) + case _ => Val.Str(pos, out) + } } } @@ -435,21 +450,22 @@ object StringModule extends AbstractFunctionModule { private object Join extends Val.Builtin2("join", "sep", "arr") { private def joinedRepeatedString( pos: Position, - sep: String, - str: String, + sep: Val.Str, + str: Val.Str, count: Int): Val.Str = { - if (count == 0) Val.Str(pos, "") + if (count == 0) Val.Str.asciiSafe(pos, "") else { - val resultLen = str.length.toLong * count + sep.length.toLong * (count - 1) + val s = str.str + val sepStr = sep.str + val resultLen = s.length.toLong * count + sepStr.length.toLong * (count - 1) if (resultLen > Int.MaxValue) Error.fail("String is too large to join") - if (count == 1) { - if (Platform.isAsciiJsonSafe(str)) Val.Str.asciiSafe(pos, str) else Val.Str(pos, str) - } else { - val asciiSafe = Platform.isAsciiJsonSafe(str) && Platform.isAsciiJsonSafe(sep) + if (count == 1) str + else { + val asciiSafe = str._asciiSafe && sep._asciiSafe val b = new java.lang.StringBuilder(resultLen.toInt) - if (str.length + sep.length <= 64) { - val repeated = str + sep + if (s.length + sepStr.length <= 64) { + val repeated = s + sepStr var i = 1 while (i < count) { b.append(repeated) @@ -458,12 +474,12 @@ object StringModule extends AbstractFunctionModule { } else { var i = 1 while (i < count) { - b.append(str) - b.append(sep) + b.append(s) + b.append(sepStr) i += 1 } } - b.append(str) + b.append(s) val result = b.toString if (asciiSafe) Val.Str.asciiSafe(pos, result) else Val.Str(pos, result) @@ -473,14 +489,14 @@ object StringModule extends AbstractFunctionModule { private def joinRepeatedStringEval( pos: Position, - sep: String, + sep: Val.Str, elem: Eval, len: Int): Val.Str = { - if (len == 0) return Val.Str(pos, "") + if (len == 0) return Val.Str.asciiSafe(pos, "") elem match { - case _: Val.Null => Val.Str(pos, "") - case s: Val.Str => joinedRepeatedString(pos, sep, s.str, len) + case _: Val.Null => Val.Str.asciiSafe(pos, "") + case s: Val.Str => joinedRepeatedString(pos, sep, s, len) case _: Val => null case _ => null } @@ -488,7 +504,7 @@ object StringModule extends AbstractFunctionModule { private def joinRepeatedDirectString( pos: Position, - sep: String, + sep: Val.Str, direct: Array[Eval], len: Int): Val.Str = { val firstEval = direct(0) @@ -498,38 +514,151 @@ object StringModule extends AbstractFunctionModule { else null } + private final val PresizedStringJoinMinParts = 16 + + private def joinPresizedStringArray( + pos: Position, + sep: Val.Str, + arr: Val.Arr, + len: Int): Val.Str = { + val sepStr = sep.str + val sepLen = sepStr.length + var totalLen = 0L + var added = false + var asciiSafe = true + var i = 0 + while (i < len) { + arr.value(i) match { + case _: Val.Null => + case x: Val.Str => + if (added) { + totalLen += sepLen + asciiSafe &&= sep._asciiSafe + } + val str = x.str + totalLen += str.length + if (totalLen > Int.MaxValue) Error.fail("String is too large to join") + asciiSafe &&= x._asciiSafe + added = true + case x => Error.fail("Cannot join " + x.prettyName) + } + i += 1 + } + + if (!added) return Val.Str.asciiSafe(pos, "") + + val b = new java.lang.StringBuilder(totalLen.toInt) + i = 0 + var needsSep = false + while (i < len) { + arr.value(i) match { + case _: Val.Null => + case x: Val.Str => + if (needsSep) b.append(sepStr) + needsSep = true + b.append(x.str) + case _ => + } + i += 1 + } + val result = b.toString + if (asciiSafe) Val.Str.asciiSafe(pos, result) else Val.Str(pos, result) + } + + private def joinDirectStringArray( + pos: Position, + sep: Val.Str, + direct: Array[Eval], + len: Int): Val.Str = { + val sepStr = sep.str + val sepLen = sepStr.length + var totalLen = 0L + var elemCount = 0 + var asciiSafe = true + var i = 0 + // Pass 1: validate element types, accumulate total char length and asciiSafe. + while (i < len) { + direct(i) match { + case _: Val.Null => + case x: Val.Str => + totalLen += x.str.length + if (totalLen > Int.MaxValue) Error.fail("String is too large to join") + asciiSafe &&= x._asciiSafe + elemCount += 1 + case _ => return null + } + i += 1 + } + if (elemCount == 0) return Val.Str.asciiSafe(pos, "") + if (elemCount > 1) { + totalLen += sepLen.toLong * (elemCount - 1) + if (totalLen > Int.MaxValue) Error.fail("String is too large to join") + asciiSafe &&= sep._asciiSafe + } + + val b = new java.lang.StringBuilder(totalLen.toInt) + i = 0 + var needsSep = false + // Pass 2: append. Pass 1 already validated all non-Null entries are Val.Str, so the + // unchecked cast below is safe and avoids a redundant pattern match dispatch. + while (i < len) { + val v = direct(i) + if (!v.isInstanceOf[Val.Null]) { + if (needsSep) b.append(sepStr) + needsSep = true + b.append(v.asInstanceOf[Val.Str].str) + } + i += 1 + } + val result = b.toString + if (asciiSafe) Val.Str.asciiSafe(pos, result) else Val.Str(pos, result) + } + def evalRhs(sep: Eval, _arr: Eval, ev: EvalScope, pos: Position): Val = { val arr = implicitly[ReadWriter[Val.Arr]].apply(_arr.value) sep.value match { case sepStr: Val.Str => - val s = sepStr.str val len = arr.length - val repeatedConst = joinRepeatedStringEval(pos, s, arr.constantEval, len) + val s = sepStr.str + val repeatedConst = joinRepeatedStringEval(pos, sepStr, arr.constantEval, len) if (repeatedConst != null) return repeatedConst - if (len == 0) return Val.Str(pos, "") + if (len == 0) return Val.Str.asciiSafe(pos, "") val direct = arr.directBackingArray if (direct != null) { - val repeated = joinRepeatedDirectString(pos, s, direct, len) + val repeated = joinRepeatedDirectString(pos, sepStr, direct, len) if (repeated != null) return repeated + + val joined = joinDirectStringArray(pos, sepStr, direct, len) + if (joined != null) return joined + } + + if (len >= PresizedStringJoinMinParts) { + return joinPresizedStringArray(pos, sepStr, arr, len) } val b = new java.lang.StringBuilder() var i = 0 var added = false + var asciiSafe = true while (i < len) { arr.value(i) match { case _: Val.Null => case x: Val.Str => - if (added) b.append(s) + if (added) { + b.append(s) + asciiSafe &&= sepStr._asciiSafe + } added = true b.append(x.str) + asciiSafe &&= x._asciiSafe case x => Error.fail("Cannot join " + x.prettyName) } i += 1 } - Val.Str(pos, b.toString) + val result = b.toString + if (asciiSafe) Val.Str.asciiSafe(pos, result) else Val.Str(pos, result) case sep: Val.Arr => val len = arr.length if (len > PresizedCopyMaxParts) { @@ -584,7 +713,12 @@ object StringModule extends AbstractFunctionModule { } } - private def splitLimit(pos: Position, str: String, cStr: String, maxSplits: Int): Array[Eval] = { + private def splitLimit( + pos: Position, + str: String, + cStr: String, + maxSplits: Int, + asciiSafe: Boolean): Array[Eval] = { if (cStr.isEmpty) { Error.fail("Cannot split by an empty string") } @@ -596,12 +730,14 @@ object StringModule extends AbstractFunctionModule { var next = if (maxSplits == 0) -1 else str.indexOf(cStr, start) while (next >= 0 && (maxSplits < 0 || sz < maxSplits)) { - b += Val.Str(pos, str.substring(start, next)) + val piece = str.substring(start, next) + b += (if (asciiSafe) Val.Str.asciiSafe(pos, piece) else Val.Str(pos, piece)) start = next + cStr.length sz += 1 next = if (maxSplits >= 0 && sz >= maxSplits) -1 else str.indexOf(cStr, start) } - b += Val.Str(pos, str.substring(start)) + val tail = str.substring(start) + b += (if (asciiSafe) Val.Str.asciiSafe(pos, tail) else Val.Str(pos, tail)) sz += 1 b.result() } @@ -637,9 +773,14 @@ object StringModule extends AbstractFunctionModule { -1 } - private def splitLimitR(pos: Position, str: String, cStr: String, maxSplits: Int): Array[Eval] = { + private def splitLimitR( + pos: Position, + str: String, + cStr: String, + maxSplits: Int, + asciiSafe: Boolean): Array[Eval] = { if (maxSplits == -1) { - return splitLimit(pos, str, cStr, maxSplits) + return splitLimit(pos, str, cStr, maxSplits, asciiSafe) } if (cStr.isEmpty) { @@ -647,7 +788,7 @@ object StringModule extends AbstractFunctionModule { } if (maxSplits >= 0 && maxSplits <= SplitLimitRPreallocMaxSplits) { - return splitLimitRBounded(pos, str, cStr, maxSplits) + return splitLimitRBounded(pos, str, cStr, maxSplits, asciiSafe) } val b = new mutable.ArrayBuilder.ofRef[Eval] @@ -657,14 +798,16 @@ object StringModule extends AbstractFunctionModule { var next = if (maxSplits == 0) -1 else lastSplitIndex(str, cStr, cLen, end - cLen) while (next >= 0 && (maxSplits < 0 || sz < maxSplits)) { - b += Val.Str(pos, str.substring(next + cLen, end)) + val piece = str.substring(next + cLen, end) + b += (if (asciiSafe) Val.Str.asciiSafe(pos, piece) else Val.Str(pos, piece)) end = next sz += 1 next = if (maxSplits >= 0 && sz >= maxSplits) -1 else lastSplitIndex(str, cStr, cLen, end - cLen) } - b += Val.Str(pos, str.substring(0, end)) + val head = str.substring(0, end) + b += (if (asciiSafe) Val.Str.asciiSafe(pos, head) else Val.Str(pos, head)) val result = b.result() var left = 0 @@ -683,7 +826,8 @@ object StringModule extends AbstractFunctionModule { pos: Position, str: String, cStr: String, - maxSplits: Int): Array[Eval] = { + maxSplits: Int, + asciiSafe: Boolean): Array[Eval] = { val cLen = cStr.length val result = new Array[Eval](maxSplits + 1) var out = maxSplits @@ -692,7 +836,8 @@ object StringModule extends AbstractFunctionModule { var next = if (maxSplits == 0) -1 else lastSplitIndex(str, cStr, cLen, end - cLen) while (next >= 0 && sz < maxSplits) { - result(out) = Val.Str(pos, str.substring(next + cLen, end)) + val piece = str.substring(next + cLen, end) + result(out) = if (asciiSafe) Val.Str.asciiSafe(pos, piece) else Val.Str(pos, piece) out -= 1 end = next sz += 1 @@ -700,7 +845,8 @@ object StringModule extends AbstractFunctionModule { if (sz >= maxSplits) -1 else lastSplitIndex(str, cStr, cLen, end - cLen) } - result(out) = Val.Str(pos, str.substring(0, end)) + val head = str.substring(0, end) + result(out) = if (asciiSafe) Val.Str.asciiSafe(pos, head) else Val.Str(pos, head) if (out == 0) result else java.util.Arrays.copyOfRange(result, out, maxSplits + 1) @@ -717,7 +863,9 @@ object StringModule extends AbstractFunctionModule { */ private object Split extends Val.Builtin2("split", "str", "c") { def evalRhs(str: Eval, c: Eval, ev: EvalScope, pos: Position): Val = { - Val.Arr(pos, splitLimit(pos, str.value.asString, c.value.asString, -1)) + val v = str.value + val safe = v.isInstanceOf[Val.Str] && v.asInstanceOf[Val.Str]._asciiSafe + Val.Arr(pos, splitLimit(pos, v.asString, c.value.asString, -1, safe)) } } @@ -733,7 +881,12 @@ object StringModule extends AbstractFunctionModule { */ private object SplitLimit extends Val.Builtin3("splitLimit", "str", "c", "maxsplits") { def evalRhs(str: Eval, c: Eval, maxSplits: Eval, ev: EvalScope, pos: Position): Val = { - Val.Arr(pos, splitLimit(pos, str.value.asString, c.value.asString, maxSplits.value.asInt)) + val v = str.value + val safe = v.isInstanceOf[Val.Str] && v.asInstanceOf[Val.Str]._asciiSafe + Val.Arr( + pos, + splitLimit(pos, v.asString, c.value.asString, maxSplits.value.asInt, safe) + ) } } @@ -746,7 +899,12 @@ object StringModule extends AbstractFunctionModule { */ private object SplitLimitR extends Val.Builtin3("splitLimitR", "str", "c", "maxsplits") { def evalRhs(str: Eval, c: Eval, maxSplits: Eval, ev: EvalScope, pos: Position): Val = { - Val.Arr(pos, splitLimitR(pos, str.value.asString, c.value.asString, maxSplits.value.asInt)) + val v = str.value + val safe = v.isInstanceOf[Val.Str] && v.asInstanceOf[Val.Str]._asciiSafe + Val.Arr( + pos, + splitLimitR(pos, v.asString, c.value.asString, maxSplits.value.asInt, safe) + ) } } @@ -888,8 +1046,15 @@ object StringModule extends AbstractFunctionModule { * Returns a copy of the string in which all ASCII letters are capitalized. */ private object AsciiUpper extends Val.Builtin1("asciiUpper", "str") { - def evalRhs(str: Eval, ev: EvalScope, pos: Position): Val = - Val.Str(pos, asciiUpper(str.value.asString)) + def evalRhs(str: Eval, ev: EvalScope, pos: Position): Val = { + val v = str.value + val s = v.asString + val out = asciiUpper(s) + v match { + case vs: Val.Str if vs._asciiSafe => Val.Str.asciiSafe(pos, out) + case _ => Val.Str(pos, out) + } + } } /** @@ -900,8 +1065,15 @@ object StringModule extends AbstractFunctionModule { * Returns a copy of the string in which all ASCII letters are lower cased. */ private object AsciiLower extends Val.Builtin1("asciiLower", "str") { - def evalRhs(str: Eval, ev: EvalScope, pos: Position): Val = - Val.Str(pos, asciiLower(str.value.asString)) + def evalRhs(str: Eval, ev: EvalScope, pos: Position): Val = { + val v = str.value + val s = v.asString + val out = asciiLower(s) + v match { + case vs: Val.Str if vs._asciiSafe => Val.Str.asciiSafe(pos, out) + case _ => Val.Str(pos, out) + } + } } /** diff --git a/sjsonnet/test/resources/new_test_suite/join_string_presized.jsonnet b/sjsonnet/test/resources/new_test_suite/join_string_presized.jsonnet new file mode 100644 index 00000000..6cbc68a8 --- /dev/null +++ b/sjsonnet/test/resources/new_test_suite/join_string_presized.jsonnet @@ -0,0 +1,39 @@ +// Directional coverage for std.join string paths: +// - small array fallback (inline StringBuilder, len < 16) +// - direct backing array path (joinDirectStringArray) +// - presized path (len >= 16, joinPresizedStringArray) +// - asciiSafe propagation (separator and parts both ASCII) +// - non-ASCII parts that should still join correctly +// - null skipping at all positions +// - all-null returns empty string + +local small = std.join("-", ["a", "bb", null, "ccc"]); +local nonAscii = std.join("/", ["é", "λ", null, "🚀"]); +local allNull = std.join("ignored", [null, null]); + +// 20 ASCII parts to force the presized path on a non-direct array. +local many = std.join( + ", ", + std.makeArray(20, function(i) std.toString(i)), +); + +// 18 ASCII parts on a direct (literal) array to exercise joinDirectStringArray. +local direct18 = std.join( + "|", + ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", + "k", "l", "m", "n", "o", "p", "q", "r"], +); + +// Mixed null + ASCII to exercise size pre-walk skipping. +local mixed = std.join( + ":", + std.makeArray(20, function(i) if i % 3 == 0 then null else std.toString(i)), +); + +std.assertEqual(small, "a-bb-ccc") && +std.assertEqual(nonAscii, "é/λ/🚀") && +std.assertEqual(allNull, "") && +std.assertEqual(many, "0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19") && +std.assertEqual(direct18, "a|b|c|d|e|f|g|h|i|j|k|l|m|n|o|p|q|r") && +std.assertEqual(mixed, "1:2:4:5:7:8:10:11:13:14:16:17:19") && +true diff --git a/sjsonnet/test/resources/new_test_suite/join_string_presized.jsonnet.golden b/sjsonnet/test/resources/new_test_suite/join_string_presized.jsonnet.golden new file mode 100644 index 00000000..27ba77dd --- /dev/null +++ b/sjsonnet/test/resources/new_test_suite/join_string_presized.jsonnet.golden @@ -0,0 +1 @@ +true diff --git a/sjsonnet/test/resources/new_test_suite/string_asciisafe_propagation.jsonnet b/sjsonnet/test/resources/new_test_suite/string_asciisafe_propagation.jsonnet new file mode 100644 index 00000000..227b2d78 --- /dev/null +++ b/sjsonnet/test/resources/new_test_suite/string_asciisafe_propagation.jsonnet @@ -0,0 +1,49 @@ +// Directional coverage for ASCII-safety propagation across StringModule builtins. +// Verifies that asciiSafe flag is correctly forwarded / cleared so that ByteRenderer's +// fast path stays correct after each transformation. + +local mj = std.manifestJson; + +// std.char: ASCII codepoint => asciiSafe; non-ASCII codepoint => not asciiSafe. +std.assertEqual(mj(std.char(65)), '"A"') && +std.assertEqual(mj(std.char(233)), '"é"') && +// Control / quote / backslash codepoints route through escape path. +std.assertEqual(mj(std.char(34)), '"\\""') && +std.assertEqual(mj(std.char(92)), '"\\\\"') && +std.assertEqual(mj(std.char(10)), '"\\n"') && + +// std.asciiUpper / asciiLower preserve asciiness — ASCII input stays asciiSafe. +std.assertEqual(mj(std.asciiUpper("hello")), '"HELLO"') && +std.assertEqual(mj(std.asciiLower("WORLD")), '"world"') && +// Non-ASCII input still renders correctly (UTF-8 path). +std.assertEqual(mj(std.asciiUpper("héllo")), '"HéLLO"') && +std.assertEqual(mj(std.asciiLower("WÖRLD")), '"wÖrld"') && + +// std.strReplace: result asciiSafe iff src and `to` are. +std.assertEqual(mj(std.strReplace("hello world", "world", "everyone")), '"hello everyone"') && +// Replace into ASCII source with non-ASCII `to` — must NOT be marked asciiSafe. +std.assertEqual(mj(std.strReplace("hello world", "world", "wörld")), '"hello wörld"') && +// Non-ASCII source — must NOT be marked asciiSafe. +std.assertEqual(mj(std.strReplace("héllo world", "world", "everyone")), '"héllo everyone"') && + +// std.lstripChars / rstripChars / stripChars preserve asciiness. +std.assertEqual(mj(std.lstripChars(" hello", " ")), '"hello"') && +std.assertEqual(mj(std.rstripChars("hello ", " ")), '"hello"') && +std.assertEqual(mj(std.stripChars(" hello ", " ")), '"hello"') && +// Strip on non-ASCII content still renders correctly. +std.assertEqual(mj(std.stripChars(" héllo ", " ")), '"héllo"') && + +// std.split / splitLimit / splitLimitR: verify each element renders correctly. +std.assertEqual(std.split("a,b,c", ","), ["a", "b", "c"]) && +std.assertEqual(std.splitLimit("a,b,c,d", ",", 2), ["a", "b", "c,d"]) && +std.assertEqual(std.splitLimitR("a,b,c,d", ",", 2), ["a,b", "c", "d"]) && +// Splits of non-ASCII string still yield correct elements. +std.assertEqual(std.split("á,b,c", ","), ["á", "b", "c"]) && +// Split with quote / backslash chars stays correct after split. +std.assertEqual(std.split("a\"b,c", ","), ["a\"b", "c"]) && +// Each element's manifestJson representation +std.assertEqual(mj(std.split("a,b,c", ",")[0]), '"a"') && +std.assertEqual(mj(std.split("á,b,c", ",")[0]), '"á"') && +std.assertEqual(mj(std.split("a\"b,c", ",")[0]), '"a\\"b"') && + +true diff --git a/sjsonnet/test/resources/new_test_suite/string_asciisafe_propagation.jsonnet.golden b/sjsonnet/test/resources/new_test_suite/string_asciisafe_propagation.jsonnet.golden new file mode 100644 index 00000000..27ba77dd --- /dev/null +++ b/sjsonnet/test/resources/new_test_suite/string_asciisafe_propagation.jsonnet.golden @@ -0,0 +1 @@ +true