Source code for nvflare.app_opt.xgboost.histogram_based.controller

# Copyright (c) 2021-2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import multiprocessing
import os

from nvflare.apis.client import Client
from nvflare.apis.controller_spec import Task
from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import Controller
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal
from nvflare.apis.workspace import Workspace
from nvflare.fuel.utils.import_utils import optional_import
from nvflare.fuel.utils.network_utils import get_open_ports
from nvflare.security.logging import secure_format_exception, secure_format_traceback

from .constants import XGB_TRAIN_TASK, XGBShareableHeader


class XGBFedController(Controller):
    """Workflow controller that hosts the XGBoost federated server for histogram-based training.

    On start it launches ``xgboost.federated.run_federated_server`` in a child
    process; in ``control_flow`` it broadcasts a single training task carrying
    the world size, rank map, server port and secure flag to every client.
    """

    # Upper bound (seconds) to wait for the XGBoost server process to exit after
    # it has been asked to terminate.
    _SERVER_JOIN_TIMEOUT = 5.0

    def __init__(self, train_timeout: int = 300, port: int = None):
        """Federated XGBoost training controller for histogram-based collaboration.

        Args:
            train_timeout (int, optional): Time to wait for clients to do local training in seconds.
            port (int, optional): the port to open XGBoost FL server. When not set
                (``None`` or ``0``), an open port is picked in ``start_controller``.

        Raises:
            TypeError: when any of input arguments does not have correct type
            ValueError: when any of input arguments is out of range
        """
        super().__init__()

        if not isinstance(train_timeout, int):
            raise TypeError(f"train_timeout must be int but got {type(train_timeout)}")
        if train_timeout < 0:
            raise ValueError(f"train_timeout must be >= 0 but got {train_timeout}")

        if port is not None:
            if not isinstance(port, int):
                raise TypeError(f"port must be int but got {type(port)}")
            if not 0 <= port <= 65535:
                raise ValueError(f"port must be in range [0, 65535] but got {port}")

        self._port = port
        self._xgb_fl_server = None
        self._participate_clients = None
        self._rank_map = None
        self._secure = False
        self._train_timeout = train_timeout
        self._server_cert_path = None
        self._server_key_path = None
        self._ca_cert_path = None
        self._started = False

    def _get_certificates(self, fl_ctx: FLContext):
        """Locate the server certificate, server key and root CA in the startup kit directory.

        The three paths are stored on the instance only when all files exist.

        Args:
            fl_ctx (FLContext): context used to look up the workspace and to log errors.

        Returns:
            bool: True when all three files exist, False as soon as one is missing
            (the missing file is reported through ``log_error``).
        """
        workspace: Workspace = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT)
        startup_dir = workspace.get_startup_kit_dir()

        # (file name, error message) pairs, checked in this order.
        required_files = (
            ("server.crt", "Missing server certificate (server.crt)"),
            ("server.key", "Missing server key (server.key)"),
            ("rootCA.pem", "Missing ca certificate (rootCA.pem)"),
        )

        resolved = []
        for file_name, missing_msg in required_files:
            path = os.path.join(startup_dir, file_name)
            if not os.path.exists(path):
                self.log_error(fl_ctx, missing_msg)
                return False
            resolved.append(path)

        self._server_cert_path, self._server_key_path, self._ca_cert_path = resolved
        return True

    def start_controller(self, fl_ctx: FLContext):
        """Compute client ranks and launch the XGBoost federated server process.

        ``self._started`` is set to True only when the server process has been
        started; on any failure the error is logged and the method returns early.

        Args:
            fl_ctx (FLContext): context used for logging and certificate lookup.
        """
        self.log_info(fl_ctx, f"Initializing {self.__class__.__name__} workflow.")

        xgb_federated, flag = optional_import(module="xgboost.federated")
        if not flag:
            self.log_error(fl_ctx, "Can't import xgboost.federated")
            return

        # Assumption: all clients are used
        # Sort by client name so rank is consistent
        clients = sorted(self._engine.get_clients(), key=lambda client: client.name)
        self._rank_map = {client.name: rank for rank, client in enumerate(clients)}
        self._participate_clients = clients

        if not self._port:
            self._port = get_open_ports(1)[0]

        self.log_info(fl_ctx, f"Starting XGBoost FL server on port {self._port}")

        server_args = (self._port, len(clients))
        self._secure = self._engine.server.secure_train
        if self._secure:
            if not self._get_certificates(fl_ctx):
                self.log_error(fl_ctx, "Can't get required certificates for XGB FL server in secure mode.")
                return
            # Secure mode passes the key, certificate and CA paths as extra positional args.
            server_args += (self._server_key_path, self._server_cert_path, self._ca_cert_path)

        self._xgb_fl_server = multiprocessing.Process(
            target=xgb_federated.run_federated_server,
            args=server_args,
        )
        self._xgb_fl_server.start()
        self._started = True

    def stop_controller(self, fl_ctx: FLContext):
        """Cancel outstanding tasks and shut down the XGBoost federated server process.

        Args:
            fl_ctx (FLContext): FL context (not used here).
        """
        self.cancel_all_tasks()

        server = self._xgb_fl_server
        if server is not None:
            # is_alive() is False for a process that never started or already
            # exited, so terminate()/join() are only called when they are valid.
            if server.is_alive():
                server.terminate()
                # Reap the child so it does not linger as a zombie.
                server.join(timeout=self._SERVER_JOIN_TIMEOUT)
                if server.is_alive():
                    # The process ignored the terminate request; force it down.
                    server.kill()
                    server.join(timeout=self._SERVER_JOIN_TIMEOUT)
            self._xgb_fl_server = None

        self._started = False

    def process_result_of_unknown_task(
        self, client: Client, task_name, client_task_id, result: Shareable, fl_ctx: FLContext
    ):
        """Log an error for a result that belongs to a task this controller does not know.

        Args:
            client (Client): client that submitted the result.
            task_name: name of the unknown task.
            client_task_id: task id reported by the client (not used here).
            result (Shareable): submitted result (not used here).
            fl_ctx (FLContext): context used for logging.
        """
        self.log_error(fl_ctx, f"Unknown task: {task_name} from client {client.name}.")

    def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
        """Broadcast the XGBoost training task to all participating clients and wait for it.

        Triggers a system panic when the controller was not started successfully
        or when an exception is raised while running the task.

        Args:
            abort_signal (Signal): signal forwarded to ``broadcast_and_wait``.
            fl_ctx (FLContext): context used for logging and task broadcasting.
        """
        self.log_info(fl_ctx, "Begin XGBoost training phase.")

        if not self._started:
            msg = "Controller does not start successfully."
            self.log_error(fl_ctx, msg)
            self.system_panic(msg, fl_ctx)
            return

        try:
            participants = self._participate_clients
            world_size = len(participants)

            data = Shareable()
            data.set_header(XGBShareableHeader.WORLD_SIZE, world_size)
            data.set_header(XGBShareableHeader.RANK_MAP, self._rank_map)
            data.set_header(XGBShareableHeader.XGB_FL_SERVER_PORT, self._port)
            data.set_header(XGBShareableHeader.XGB_FL_SERVER_SECURE, self._secure)

            train_task = Task(
                name=XGB_TRAIN_TASK,
                data=data,
                timeout=self._train_timeout,
            )

            # Every participating client must respond.
            self.broadcast_and_wait(
                task=train_task,
                targets=participants,
                min_responses=world_size,
                fl_ctx=fl_ctx,
                abort_signal=abort_signal,
            )

            self.log_info(fl_ctx, "Finish training phase.")
        # NOTE(review): BaseException also covers KeyboardInterrupt/SystemExit,
        # which therefore end in a system panic; kept to preserve behavior.
        except BaseException as e:
            err = secure_format_traceback()
            error_msg = f"Exception in control_flow: {secure_format_exception(e)}: {err}"
            self.log_exception(fl_ctx, error_msg)
            self.system_panic(secure_format_exception(e), fl_ctx)