# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import shlex
import signal
import subprocess
import threading
import time
from abc import abstractmethod
from multiprocessing.connection import Client, Listener

from nvflare.apis.event_type import EventType
from nvflare.apis.executor import Executor
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_constant import FLContextKey, ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.apis.utils.fl_context_utils import get_serializable_data
from nvflare.fuel.common.multi_process_executor_constants import CommunicateData, CommunicationMetaData
from nvflare.fuel.utils.class_utils import ModuleScanner
from nvflare.fuel.utils.component_builder import ComponentBuilder
from nvflare.fuel.utils.network_utils import get_open_ports
class WorkerComponentBuilder(ComponentBuilder):
    """ComponentBuilder used by MultiProcessExecutor to build its configured components.

    It exposes a ModuleScanner over the ``FL_PACKAGES`` packages and the ``FL_MODULES``
    sub-modules through ``get_module_scanner()``.
    """

    FL_PACKAGES = ["nvflare"]
    FL_MODULES = ["client", "app"]

    def __init__(self) -> None:
        """Component to build workers."""
        super().__init__()
        builder_cls = WorkerComponentBuilder
        # NOTE(review): meaning of the trailing positional flag (True) is defined by
        # ModuleScanner, which is not visible here — confirm before changing it.
        self.module_scanner = ModuleScanner(builder_cls.FL_PACKAGES, builder_cls.FL_MODULES, True)

    def get_module_scanner(self):
        """Return the ModuleScanner created in ``__init__``."""
        return self.module_scanner
class MultiProcessExecutor(Executor):
    """Executor that runs a wrapped executor inside a group of child (rank) processes.

    The child processes are launched with the command returned by
    ``get_multi_process_command()``. For every rank the parent opens two client
    connections: ``EXE_CONN`` carries init, execute and close traffic, and
    ``HANDLE_CONN`` carries event traffic. Each rank also gets one relay thread that
    re-fires, on the parent engine, the events that rank fires.
    """

    # Upper bound (seconds) that finalize() waits for each relay thread. A relay thread whose
    # rank process never connected stays blocked in Listener.accept(); an unbounded join on
    # it would make finalize() hang forever.
    RELAY_THREAD_JOIN_TIMEOUT = 10.0

    def __init__(self, executor_id=None, num_of_processes=1, components=None):
        """Manage the multi-process execution life cycle.

        Arguments:
            executor_id: ID of the component in ``components`` to use as the executor
            num_of_processes: number of processes to create
            components: iterable of component config dicts, each with an ``"id"`` key, that
                are built with ``WorkerComponentBuilder``; ``None`` means no components

        Raises:
            TypeError: if ``num_of_processes`` is not an int, or a component config has no id
            ValueError: if ``num_of_processes`` is smaller than 1
        """
        super().__init__()
        # Validate the cheap argument before doing the (possibly expensive) component builds.
        if not isinstance(num_of_processes, int):
            raise TypeError(f"num_of_processes must be an instance of int but got {type(num_of_processes)}")
        if num_of_processes < 1:
            raise ValueError(f"num_of_processes must be >= 1 but got {num_of_processes}.")
        self.executor_id = executor_id
        self.components = {}
        self.handlers = []
        self._build_components(components)
        self.num_of_processes = num_of_processes
        self.executor = None
        self.logger = logging.getLogger(self.__class__.__name__)
        self.conn_clients = []
        self.exe_process = None
        self.open_ports = []
        self.stop_execute = False
        self.relay_threads = []
        self.finalized = False
        self.event_lock = threading.Lock()
        self.relay_lock = threading.Lock()

    @abstractmethod
    def get_multi_process_command(self) -> str:
        """Provide the command for starting multi-process execution.

        The returned string is the launcher prefix. ``_initialize_multi_process`` appends
        the sub-worker module and its arguments, then splits the result with ``shlex.split``.

        Returns:
            multi-process starting command
        """
        return ""

    def _build_components(self, components):
        """Build every configured component and keep the FLComponents as event handlers.

        Raises:
            TypeError: if a component config has no (or an empty) ``"id"``.
        """
        component_builder = WorkerComponentBuilder()
        # ``components`` defaults to None in __init__; treat that as "no components".
        for item in components or []:
            cid = item.get("id", None)
            if not cid:
                raise TypeError("missing component id")
            component = component_builder.build_component(item)
            self.components[cid] = component
            if isinstance(component, FLComponent):
                self.handlers.append(component)

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        """Start the ranks on START_RUN and finalize on END_RUN, then relay the event to the ranks."""
        if event_type == EventType.START_RUN:
            self.initialize(fl_ctx)
        elif event_type == EventType.END_RUN:
            self.finalize(fl_ctx)
        self._pass_event_to_rank_processes(event_type, fl_ctx)

    def _pass_event_to_rank_processes(self, event_type: str, fl_ctx: FLContext):
        """Forward an event fired in this process to every rank process.

        Events that were themselves relayed up from a rank (marked by ``_fire_relayed_event``
        with EVENT_ORIGIN_SITE == SUB_WORKER_PROCESS) are not sent back down. Only rank 0's
        reply is read; its FLContext props are merged into ``fl_ctx``. Failures are logged,
        not raised.
        """
        event_site = fl_ctx.get_prop(FLContextKey.EVENT_ORIGIN_SITE)
        if event_site == CommunicateData.SUB_WORKER_PROCESS:
            return
        if not self.conn_clients:
            # No rank process is connected (e.g. before START_RUN or after a failed init),
            # so there is nobody to relay to.
            return
        with self.event_lock:
            try:
                data = {
                    CommunicationMetaData.COMMAND: CommunicateData.HANDLE_EVENT,
                    CommunicationMetaData.FL_CTX: get_serializable_data(fl_ctx),
                    CommunicationMetaData.EVENT_TYPE: event_type,
                }
                # send the event to all the child processes
                for conn_client in self.conn_clients:
                    conn_client[CommunicationMetaData.HANDLE_CONN].send(data)
                return_data = self.conn_clients[0][CommunicationMetaData.HANDLE_CONN].recv()
                # update the fl_ctx from the child process return data.
                fl_ctx.props.update(return_data[CommunicationMetaData.FL_CTX].props)
            except Exception:
                # Warning: Have to set fire_event=False, otherwise it will cause dead loop on the event handling!!!
                self.log_warning(
                    fl_ctx, f"Failed to relay the event to child processes. Event: {event_type}", fire_event=False
                )

    def initialize(self, fl_ctx: FLContext):
        """Resolve the configured executor and start the rank processes.

        Raises:
            ValueError: if ``executor_id`` does not name an ``Executor`` component.
        """
        self.executor = self.components.get(self.executor_id, None)
        if not isinstance(self.executor, Executor):
            raise ValueError(
                "invalid executor {}: expect Executor but got {}".format(self.executor_id, type(self.executor))
            )
        self._initialize_multi_process(fl_ctx)

    def _initialize_multi_process(self, fl_ctx: FLContext):
        """Launch the rank processes, connect to them, and hand them the init data.

        Any failure is logged with its traceback rather than raised.
        """
        try:
            # Three ports per rank: [exe conn, handle conn, relay listener].
            self.open_ports = get_open_ports(self.num_of_processes * 3)
            command = self._build_worker_command(fl_ctx)
            self.logger.info(f"multi_process_executor command: {command}")
            # use os.setsid to create new process group ID, so finalize() can kill the whole group
            self.exe_process = subprocess.Popen(shlex.split(command), preexec_fn=os.setsid, env=os.environ.copy())
            self._connect_rank_processes(fl_ctx)
            self._send_init_data(fl_ctx)
        except Exception:
            self.log_exception(fl_ctx, "error initializing multi_process executor")

    def _build_worker_command(self, fl_ctx: FLContext) -> str:
        """Return the command line (shell-style string) that launches the sub-worker processes."""
        client_name = fl_ctx.get_identity_name()
        job_id = fl_ctx.get_job_id()
        simulate_mode = fl_ctx.get_prop(FLContextKey.SIMULATE_MODE, False)
        workspace = fl_ctx.get_prop(FLContextKey.ARGS).workspace
        # Quote externally supplied values so paths/names containing spaces or shell
        # metacharacters survive the shlex.split() in _initialize_multi_process intact.
        return (
            self.get_multi_process_command()
            + " -m nvflare.private.fed.app.client.sub_worker_process"
            + " -m "
            + shlex.quote(workspace)
            + " -c "
            + shlex.quote(client_name)
            + " -n "
            + shlex.quote(job_id)
            + " --ports "
            + "-".join(str(port) for port in self.open_ports)
            + " --simulator_engine "
            + str(simulate_mode)
            + " --parent_pid "
            + str(os.getpid())
        )

    def _connect_rank_processes(self, fl_ctx: FLContext):
        """Start one relay thread per rank and open that rank's EXE and HANDLE connections."""
        for rank in range(self.num_of_processes):
            exe_port, handle_port, relay_port = self.open_ports[rank * 3 : rank * 3 + 3]
            # Daemon thread: one stuck in Listener.accept() must not keep the interpreter alive.
            thread = threading.Thread(target=self._relay_fire_event, args=(relay_port, fl_ctx), daemon=True)
            self.relay_threads.append(thread)
            thread.start()
            exe_conn = self._create_connection(exe_port)
            handle_conn = self._create_connection(handle_port)
            self.conn_clients.append(
                {CommunicationMetaData.EXE_CONN: exe_conn, CommunicationMetaData.HANDLE_CONN: handle_conn}
            )
        self.logger.info(f"Created the connections to child processes on ports: {str(self.open_ports)}")

    def _send_init_data(self, fl_ctx: FLContext):
        """Send the init data to every rank, then wait for each to answer (or for a run abort)."""
        data = {
            CommunicationMetaData.FL_CTX: get_serializable_data(fl_ctx),
            CommunicationMetaData.COMPONENTS: self.components,
            CommunicationMetaData.HANDLERS: self.handlers,
            CommunicationMetaData.LOCAL_EXECUTOR: self.executor,
        }
        # send the init data to all the child processes
        for conn_client in self.conn_clients:
            conn_client[CommunicationMetaData.EXE_CONN].send(data)

        # Make sure to receive responses from all rank processes.
        responses = [False] * len(self.conn_clients)
        while not all(responses):
            run_abort_signal = fl_ctx.get_run_abort_signal()
            if run_abort_signal and run_abort_signal.triggered:
                self.finalize(fl_ctx)
                return
            for index, conn_client in enumerate(self.conn_clients):
                exe_conn = conn_client[CommunicationMetaData.EXE_CONN]
                if not responses[index] and exe_conn.poll(0.2):
                    exe_conn.recv()
                    responses[index] = True

    def _relay_fire_event(self, listen_port, fl_ctx: FLContext):
        """Thread target: re-fire on the parent engine every event one rank process fires.

        Accepts the rank's connection on ``listen_port`` and serves it until ``finalize()`` sets
        ``stop_execute`` or the rank closes its end. The listener and connection are closed
        on exit.
        """
        address = ("localhost", int(listen_port))
        with Listener(address, authkey=CommunicationMetaData.PARENT_PASSWORD.encode()) as listener:
            with listener.accept() as conn:
                while not self.stop_execute:
                    try:
                        if not conn.poll(0.1):
                            continue
                        data = conn.recv()
                    except (EOFError, OSError):
                        # The rank closed its end: poll() would keep reporting readable and recv()
                        # keep failing, so stop instead of spinning.
                        self.logger.warning("Connection from rank process closed; stop relaying its events.")
                        break
                    try:
                        return_data = self._fire_relayed_event(data, fl_ctx)
                        conn.send(return_data)
                    except Exception:
                        self.logger.warning("Failed to relay the fired events from rank_processes.")

    def _fire_relayed_event(self, data, fl_ctx: FLContext):
        """Fire one event message received from a rank on the parent engine; return the reply payload."""
        event_type = data[CommunicationMetaData.EVENT_TYPE]
        rank_number = data[CommunicationMetaData.RANK_NUMBER]
        # fl_ctx is shared by all relay threads, so mutate/fire/snapshot it under one lock.
        with self.relay_lock:
            fl_ctx.props.update(data[CommunicationMetaData.FL_CTX].props)
            fl_ctx.set_prop(FLContextKey.FROM_RANK_NUMBER, rank_number, private=True, sticky=False)
            # Mark the origin so _pass_event_to_rank_processes does not send the event back down.
            fl_ctx.set_prop(
                FLContextKey.EVENT_ORIGIN_SITE,
                CommunicateData.SUB_WORKER_PROCESS,
                private=True,
                sticky=False,
            )
            engine = fl_ctx.get_engine()
            engine.fire_event(event_type, fl_ctx)
            return {CommunicationMetaData.FL_CTX: get_serializable_data(fl_ctx)}

    def _create_connection(self, open_port):
        """Connect to a rank process listening on ``open_port``, retrying once per second.

        Raises:
            RuntimeError: if the launched process has exited, since the port can then never open.
        """
        address = ("localhost", open_port)
        while True:
            try:
                return Client(address, authkey=CommunicationMetaData.CHILD_PASSWORD.encode())
            except Exception:
                if self.exe_process is not None and self.exe_process.poll() is not None:
                    raise RuntimeError(
                        f"multi-process command exited with code {self.exe_process.returncode} "
                        f"before port {open_port} could be connected"
                    )
                time.sleep(1.0)

    def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
        """Run ``task_name`` in the rank processes and return rank 0's result.

        Raises:
            RuntimeError: if no executor has been resolved (``initialize`` not run or failed).
        """
        if not self.executor:
            raise RuntimeError("There's no executor for task {}".format(task_name))
        return self._execute_multi_process(
            task_name=task_name, shareable=shareable, fl_ctx=fl_ctx, abort_signal=abort_signal
        )

    def _execute_multi_process(
        self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal
    ) -> Shareable:
        """Send the task to every rank and wait for the Shareable reported by rank 0.

        Returns an OK reply (after finalizing) when aborted, and an EXECUTION_RESULT_ERROR
        reply (with the traceback logged) when communication with the ranks fails.
        """
        if abort_signal.triggered:
            self.finalize(fl_ctx)
            return make_reply(ReturnCode.OK)
        try:
            data = {
                CommunicationMetaData.COMMAND: CommunicateData.EXECUTE,
                CommunicationMetaData.TASK_NAME: task_name,
                CommunicationMetaData.SHAREABLE: shareable,
                CommunicationMetaData.FL_CTX: get_serializable_data(fl_ctx),
            }
            # send the execute command to all the child processes
            for conn_client in self.conn_clients:
                conn_client[CommunicationMetaData.EXE_CONN].send(data)
            while True:
                if abort_signal.triggered:
                    self.finalize(fl_ctx)
                    return make_reply(ReturnCode.OK)
                rank0_conn = self.conn_clients[0][CommunicationMetaData.EXE_CONN]
                if rank0_conn.poll(1.0):
                    # Only need to receive the Shareable and FLContext update from rank 0 process.
                    return_data = rank0_conn.recv()
                    fl_ctx.props.update(return_data[CommunicationMetaData.FL_CTX].props)
                    return return_data[CommunicationMetaData.SHAREABLE]
        except Exception:
            self.log_exception(fl_ctx, "Multi-Process Execution error.")
            return make_reply(ReturnCode.EXECUTION_RESULT_ERROR)

    def finalize(self, fl_ctx: FLContext):
        """This is called when exiting/aborting the executor.

        Only the first call has any effect. It sends CLOSE to every rank, kills the launched
        process group, and waits (bounded by RELAY_THREAD_JOIN_TIMEOUT per thread) for the
        relay threads to stop.
        """
        if self.finalized:
            return
        self.finalized = True
        self.stop_execute = True
        data = {CommunicationMetaData.COMMAND: CommunicateData.CLOSE}
        for conn_client in self.conn_clients:
            try:
                conn_client[CommunicationMetaData.EXE_CONN].send(data)
                conn_client[CommunicationMetaData.HANDLE_CONN].send(data)
                self.logger.info("close command sent to rank processes.")
            except Exception:
                self.logger.warning("Failed to send the close command. ")
        if self.exe_process:
            try:
                # The process was started in its own process group (os.setsid); kill the whole
                # group so every rank process it spawned goes down with it.
                os.killpg(os.getpgid(self.exe_process.pid), signal.SIGKILL)
                self.logger.debug("kill signal sent")
            except Exception:
                pass  # best effort: the group may already be gone
            self.exe_process.terminate()
        # wait for the relay threads, but never forever and never on ourselves
        current = threading.current_thread()
        for t in self.relay_threads:
            if t is current or not t.is_alive():
                continue
            t.join(timeout=self.RELAY_THREAD_JOIN_TIMEOUT)
            if t.is_alive():
                self.logger.warning("Relay thread did not stop in time; leaving it behind.")
        self.log_info(fl_ctx, "Multi-Process Executor finalized!", fire_event=False)