diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 4e6089b50a46d..1524ff455ec75 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -415,13 +415,17 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
      */
     def writeNextInputToStream(dataOut: DataOutputStream): Boolean
 
-    def open(dataOut: DataOutputStream): Unit = Utils.logUncaughtExceptions {
+    def open(outputStream: DataOutputStream): Unit = Utils.logUncaughtExceptions {
       val isUnixDomainSock = authHelper.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED)
       lazy val sockPath = new File(
         authHelper.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_DIR)
           .getOrElse(System.getProperty("java.io.tmpdir")),
         s".${UUID.randomUUID()}.sock")
       try {
+        // Buffer the initialization message and send it together with its length.
+        val buffer = new ByteArrayOutputStream()
+        val dataOut = new DataOutputStream(buffer)
+
         // Partition index
         dataOut.writeInt(partitionIndex)
 
@@ -522,6 +526,13 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
         writeCommand(dataOut)
 
         dataOut.flush()
+
+        // The initialization message is complete; write it to the stream with its length.
+        val messageBytes = buffer.toByteArray
+        outputStream.writeInt(SpecialLengths.START_OF_INIT_MESSAGE)
+        outputStream.writeInt(messageBytes.length)
+        outputStream.write(messageBytes)
+        outputStream.flush()
       } catch {
         case t: Throwable if NonFatal(t) || t.isInstanceOf[Exception] =>
           if (context.isCompleted() || context.isInterrupted()) {
@@ -1085,6 +1096,7 @@ private[spark] object SpecialLengths {
   val NULL = -5
   val START_ARROW_STREAM = -6
   val END_OF_MICRO_BATCH = -7
+  val START_OF_INIT_MESSAGE = -8
 }
 
 private[spark] object BarrierTaskContextMessageProtocol {
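
Reviewer note (illustration only, not part of the patch): with this change, PythonRunner no
longer streams the init fields to the worker one by one. It buffers them in a
ByteArrayOutputStream and frames the result with a marker and a length, so the worker can
consume the whole init message in a single read. A minimal Python sketch of the resulting
framing; struct stands in for DataOutputStream.writeInt and the payload is a placeholder:

    import io
    import struct

    START_OF_INIT_MESSAGE = -8  # mirrors SpecialLengths.START_OF_INIT_MESSAGE

    init_payload = b"<partition index, python version, conf, ...>"  # placeholder bytes

    out = io.BytesIO()
    out.write(struct.pack("!i", START_OF_INIT_MESSAGE))  # marker (big-endian, like writeInt)
    out.write(struct.pack("!i", len(init_payload)))      # length of the buffered message
    out.write(init_payload)                              # the buffered init message itself
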
diff --git a/python/packaging/classic/setup.py b/python/packaging/classic/setup.py
index 911c50141e43f..8c201d4c25807 100755
--- a/python/packaging/classic/setup.py
+++ b/python/packaging/classic/setup.py
@@ -267,6 +267,8 @@ def run(self):
             "pyspark",
             "pyspark.core",
             "pyspark.cloudpickle",
+            "pyspark.messages",
+            "pyspark.messages.socket",
             "pyspark.mllib",
             "pyspark.mllib.linalg",
             "pyspark.mllib.stat",
diff --git a/python/packaging/client/setup.py b/python/packaging/client/setup.py
index 17475e9e065ad..182ec11ab2d77 100755
--- a/python/packaging/client/setup.py
+++ b/python/packaging/client/setup.py
@@ -148,6 +148,8 @@
 connect_packages = [
     "pyspark",
     "pyspark.cloudpickle",
+    "pyspark.messages",
+    "pyspark.messages.socket",
     "pyspark.mllib",
     "pyspark.mllib.linalg",
     "pyspark.mllib.stat",
diff --git a/python/pyspark/messages/__init__.py b/python/pyspark/messages/__init__.py
index ccb7b9323257f..69cfbf6bd53a2 100644
--- a/python/pyspark/messages/__init__.py
+++ b/python/pyspark/messages/__init__.py
@@ -15,8 +15,12 @@
 # limitations under the License.
 #
 
+from pyspark.messages.spark_message_receiver import SparkMessageReceiver
 from pyspark.messages.zero_copy_byte_stream import ZeroCopyByteStream
+from pyspark.messages.socket.spark_socket_message_receiver import SparkSocketMessageReceiver
 
 __all__ = [
+    "SparkMessageReceiver",
+    "SparkSocketMessageReceiver",
     "ZeroCopyByteStream",
 ]
diff --git a/python/pyspark/messages/socket/__init__.py b/python/pyspark/messages/socket/__init__.py
new file mode 100644
index 0000000000000..cce3acad34a49
--- /dev/null
+++ b/python/pyspark/messages/socket/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/python/pyspark/messages/socket/spark_socket_message_receiver.py b/python/pyspark/messages/socket/spark_socket_message_receiver.py
new file mode 100644
index 0000000000000..fe46d988e8392
--- /dev/null
+++ b/python/pyspark/messages/socket/spark_socket_message_receiver.py
@@ -0,0 +1,64 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import BinaryIO
+
+from pyspark.serializers import read_int, SpecialLengths
+from pyspark.messages.zero_copy_byte_stream import ZeroCopyByteStream
+from pyspark.messages.spark_message_receiver import (
+    SparkMessageReceiver,
+)
+
+
+def _assert_message_id(message_id: int, expected: int) -> None:
+    assert message_id == expected, (
+        f"Expected message with id {expected} "
+        f"but got message with id {message_id} instead."
+    )
+
+
+class SparkSocketMessageReceiver(SparkMessageReceiver):
+    def __init__(self, infile: BinaryIO):
+        super().__init__()
+        self._infile = infile
+
+    def _do_get_init_message(self) -> ZeroCopyByteStream:
+        message_id = read_int(self._infile)
+        _assert_message_id(message_id, SpecialLengths.START_OF_INIT_MESSAGE)
+
+        # Read the length and init content
+        message_length = read_int(self._infile)
+        message_content = self._infile.read(message_length)
+
+        return ZeroCopyByteStream(memoryview(message_content))
+
+    def _do_get_data_stream(self) -> BinaryIO:
+        # For socket communication, we just pass along the underlying socket
+        # for the data channel. We have already stripped the initialization
+        # data at this point, so any bytes that follow are data bytes.
+        #
+        # Note: We deliberately did not introduce a message header for
+        # data messages to reduce the overhead, especially for small
+        # batch sizes and real-time mode (RTM).
+        return self._infile
+
+    def _do_is_stream_finished(self) -> bool:
+        # Check if the stream is finished. If everything finished
+        # properly, we should read an 'END_OF_STREAM'. If we read
+        # something else, the stream has unread data and something
+        # went wrong during processing.
+        return read_int(self._infile) == SpecialLengths.END_OF_STREAM
diff --git a/python/pyspark/messages/spark_message_receiver.py b/python/pyspark/messages/spark_message_receiver.py
new file mode 100644
index 0000000000000..903a2fb114083
--- /dev/null
+++ b/python/pyspark/messages/spark_message_receiver.py
@@ -0,0 +1,126 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from enum import Enum
+from functools import wraps
+from typing import BinaryIO, Callable, TypeVar
+from abc import ABC, abstractmethod
+
+from pyspark.messages.zero_copy_byte_stream import ZeroCopyByteStream
+
+
+T = TypeVar("T", bound="SparkMessageReceiver")
+R = TypeVar("R")
+
+
+class MessageState(Enum):
+    WAITING_FOR_INIT = 1
+    WAITING_FOR_DATA = 2
+    WAITING_FOR_FINISH = 3
+    DONE = 4
+
+
+class SparkMessageReceiver(ABC):
+    """
+    Generic class that implements receiving messages from Spark.
+    Caution: This class is STATEFUL. It is expected that the
+    methods of this class are called in the following order:
+
+    1. Init -> 2. Data stream -> 3. Finish
+
+    This order is verified using assertions in the class. Each function
+    can be called EXACTLY ONCE in the specified order.
+ """ + + def __init__(self) -> None: + self._state = MessageState.WAITING_FOR_INIT + + @staticmethod + def _state_transition( + required_state: MessageState, next_state: MessageState + ) -> Callable[[Callable[[T], R]], Callable[[T], R]]: + """Decorator to enforce state transitions.""" + + def decorator(func: Callable[[T], R]) -> Callable[[T], R]: + @wraps(func) + def wrapper(self: T) -> R: + assert self._state == required_state + result = func(self) + self._state = next_state + return result + + return wrapper + + return decorator + + @_state_transition(MessageState.WAITING_FOR_INIT, MessageState.WAITING_FOR_DATA) + def get_init_message(self) -> ZeroCopyByteStream: + """ + Returns: + the binary contents of the initial message as a ZeroCopyByteStream. + """ + return self._do_get_init_message() + + @_state_transition(MessageState.WAITING_FOR_DATA, MessageState.WAITING_FOR_FINISH) + def get_data_stream(self) -> BinaryIO: + """ + Returns: + A binary stream containing the data to invoke the UDF on. + """ + return self._do_get_data_stream() + + @_state_transition(MessageState.WAITING_FOR_FINISH, MessageState.DONE) + def is_stream_finished(self) -> bool: + """ + Checks if a finish message was received + from the JVM. The finish message itself only + has a message id and marks the end of the stream. + If bytes different from the finish id are read + this means something went wrong while consuming the stream. + """ + return self._do_is_stream_finished() + + @abstractmethod + def _do_get_init_message(self) -> ZeroCopyByteStream: + """ + Returns the contents of the init message + as a 'ZeroCopyByteStream'. + + To be implemented by child classes. + """ + pass + + @abstractmethod + def _do_get_data_stream(self) -> BinaryIO: + """ + Returns the Spark data stream. + + To be implemented by child classes. + """ + pass + + @abstractmethod + def _do_is_stream_finished(self) -> bool: + """ + Blocking call that returns whether + the data stream from the JVM is finished. + This is implemented differently, depending + on the transport channel. + + To be implemented by child classes. + """ + pass diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 6de64a1062f0b..48166c948b5b1 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -63,6 +63,7 @@ import zlib import itertools import pickle +import codecs pickle_protocol = pickle.HIGHEST_PROTOCOL @@ -84,6 +85,7 @@ class SpecialLengths: END_OF_STREAM = -4 NULL = -5 START_ARROW_STREAM = -6 + START_OF_INIT_MESSAGE = -8 class Serializer: @@ -539,7 +541,7 @@ def loads(self, stream): elif length == SpecialLengths.NULL: return None s = stream.read(length) - return s.decode("utf-8") if self.use_unicode else s + return codecs.decode(s, "utf-8") if self.use_unicode else s def load_stream(self, stream): try: diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 3b96f02e04b7b..44ea62f626545 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -39,6 +39,7 @@ Union, get_args, get_origin, + BinaryIO, ) T = TypeVar("T") @@ -120,6 +121,10 @@ Conf, ) from pyspark.logger.worker_io import capture_outputs +from pyspark.messages import ( + SparkMessageReceiver, + SparkSocketMessageReceiver, +) class RunnerConf(Conf): @@ -3566,11 +3571,20 @@ def func(_, it): return func, None, ser, ser -@with_faulthandler -def main(infile, outfile): +def invoke_udf(message_receiver: SparkMessageReceiver, outfile: BinaryIO): + """ + This function is the main processing function for worker.py. 
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 3b96f02e04b7b..44ea62f626545 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -39,6 +39,7 @@
     Union,
     get_args,
     get_origin,
+    BinaryIO,
 )
 
 T = TypeVar("T")
@@ -120,6 +121,10 @@
     Conf,
 )
 from pyspark.logger.worker_io import capture_outputs
+from pyspark.messages import (
+    SparkMessageReceiver,
+    SparkSocketMessageReceiver,
+)
 
 
 class RunnerConf(Conf):
@@ -3566,11 +3571,20 @@ def func(_, it):
     return func, None, ser, ser
 
 
-@with_faulthandler
-def main(infile, outfile):
+def invoke_udf(message_receiver: SparkMessageReceiver, outfile: BinaryIO):
+    """
+    The main processing function for worker.py: it receives messages from
+    the JVM, processes the data, and sends back results, in three phases:
+
+    Initialization -> Processing -> Finish/Cleanup
+    """
     try:
         boot_time = time.time()
+
+        # Initialization
+        infile = message_receiver.get_init_message()
         init_info = WorkerInitInfo.from_stream(infile)
+
         start_faulthandler_periodic_traceback()
 
         check_python_version(init_info.python_version)
@@ -3610,8 +3624,13 @@ def main(infile, outfile):
 
         init_time = time.time()
 
+        # Processing
+
+        # Fetch the input data stream
+        input_data_stream = message_receiver.get_data_stream()
+
         def process():
-            iterator = deserializer.load_stream(infile)
+            iterator = deserializer.load_stream(input_data_stream)
             out_iter = func(init_info.split_index, iterator)
             try:
                 serializer.dump_stream(out_iter, outfile)
@@ -3627,6 +3646,7 @@ def process():
         process()
         processing_time_ms = int(1000 * (time.time() - processing_start_time))
 
+        # Cleanup
         # Reset task context to None. This is a guard code to avoid residual context when worker
         # reuse.
         TaskContext._setTaskContext(None)
@@ -3644,7 +3664,7 @@ def process():
         send_accumulator_updates(outfile)
 
         # check end of stream
-        if read_int(infile) == SpecialLengths.END_OF_STREAM:
+        if message_receiver.is_stream_finished():
             write_int(SpecialLengths.END_OF_STREAM, outfile)
         else:
             # write a different value to tell JVM to not reuse this worker
@@ -3652,6 +3672,14 @@ def process():
             write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
             sys.exit(-1)
 
+@with_faulthandler
+def main(infile, outfile):
+    # Instantiate a socket message receiver for executing the UDF
+    socket_reader = SparkSocketMessageReceiver(infile)
+
+    invoke_udf(socket_reader, outfile)
+
+
 if __name__ == "__main__":
     with get_sock_file_to_executor() as sock_file:
         main(sock_file, sock_file)
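
Reviewer note (illustration only, not part of the patch): end to end, the channel the
reworked worker sees carries START_OF_INIT_MESSAGE (-8), the init message length, the init
payload, then the unframed data stream, and finally an END_OF_STREAM (-4) marker (in the
real exchange, results and accumulator updates flow back to the JVM before that marker is
read). A runnable sketch against an in-memory channel; read_int/write_int are stand-ins for
the pyspark.serializers helpers of the same name, and the payloads are placeholders:

    import io
    import struct

    START_OF_INIT_MESSAGE = -8  # SpecialLengths.START_OF_INIT_MESSAGE
    END_OF_STREAM = -4          # SpecialLengths.END_OF_STREAM

    def write_int(stream, value):
        stream.write(struct.pack("!i", value))

    def read_int(stream):
        return struct.unpack("!i", stream.read(4))[0]

    # "JVM" side: framed init message, unframed data, end-of-stream marker.
    init_payload = b"<init fields>"
    data_payload = b"<input rows>"

    channel = io.BytesIO()
    write_int(channel, START_OF_INIT_MESSAGE)
    write_int(channel, len(init_payload))
    channel.write(init_payload)
    channel.write(data_payload)
    write_int(channel, END_OF_STREAM)
    channel.seek(0)

    # Worker side, mirroring invoke_udf(): init -> data -> finish.
    assert read_int(channel) == START_OF_INIT_MESSAGE      # get_init_message()
    init_message = channel.read(read_int(channel))
    data = channel.read(len(data_payload))                 # get_data_stream()
    assert read_int(channel) == END_OF_STREAM              # is_stream_finished()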