0%

【联邦学习之旅】C2 FATE Flow 关键过程源码解析

上一篇文章中,我们通过源码来初步了解了一下调度器 FATE Flow 的工作原理,内容比较多且杂。而在实际的工作中,我们往往需要关注的是更加细致的处理细节,本文将挑选一些要点来进行解析和说明。

注:本文基于 FATE 1.6.0 版本,后续版本的代码将另外标注出变化。

调度器如何判断 Task 执行完成

入口文件:python/fate_flow/scheduler/dag_scheduler.py,入口函数 schedule_running_job -> TaskScheduler.schedule -> TaskScheduler.federated_task_status。在这个函数中我们会得到对应 job 所有 task 的状态,并基于此进行整个 job 状态的计算。我们在日志中经常看到的

1
2
[INFO] [2021-09-22 18:35:12,346] [33677:140465483446016] - dag_scheduler.py[line:310]: Job 202109221624306330774 status is running, calculate by task status list: ['success', 'success', 'running']
[INFO] [2021-09-22 18:35:14,400] [33677:140465483446016] - task_scheduler.py[line:143]: job 202109221624306330774 task 202109221624306330774_dataio_0 0 status is success, calculate by task party status list: ['success', 'success']

这里的具体状态,是直接从 mysql 中对应表中取出来的。那么我们就引入了一个新的问题,具体一个任务在执行的时候,能让调度器知道自己执行完成了呢?

要回答这个问题,我们先来看看调度器是如何知道一个任务启动起来的,很简单,因为是调度器发起的,调度器自然知道呀,直接将状态设置为 RUNNING 即可

1
2
task.f_status = TaskStatus.RUNNING
update_status = JobSaver.update_task_status(task_info=task.to_human_model_dict(only_primary_with=["status"]))

注:这里修改状态和实际发起任务并不是一个原子操作,即存在一种可能,task.f_status 已经变成 RUNNING 但实际上计算进程没有拉起,所以会另外用一个调度的状态 SchedulingStatusCode.SUCCESS 来进行跟踪。拉起失败后会再次更新 task.f_status 的状态为 StatusSet.FAILED,并再次和各个 party 同步状态。

知道了如何发起任务,我们可以依葫芦画瓢,直接搜索 TaskStatus.SUCCESSTaskStatus.FAILED(这是一个很方便的举一反三的阅读源码方法,小本本记一下)。搜索结果并不多,只看赋值的语句就只有一条,位于 python/fate_flow/operation/task_executor.py -> run_task() 中,具体如下:

1
2
3
4
5
6
7
...
profile.profile_start()
run_object.run(component_parameters_on_party, task_run_args)
profile.profile_ends()
...
tracker.save_output_model(output_model,task_output_dsl['model'][0] if task_output_dsl.get('model') else 'default')
task_info["party_status"] = TaskStatus.SUCCESS

这里我们可以看到,通过调用 run_object.run() 函数来执行任务,并且在任务执行完成后,更新 task 状态(通过调用 http 接口来更新,具体函数是 report_task_update_to_driver)。

调度器则是直接通过命令行创建新的进程进行执行,代码如下(直接看 python/fate_flow/operation/task_executor.py -> run_task() 就可以看到对应结果):

1
2
3
4
5
6
7
8
9
10
11
12
13
process_cmd = [
sys.executable,
sys.modules[TaskExecutor.__module__].__file__,
'-j', job_id,
'-n', component_name,
'-t', task_id,
'-v', task_version,
'-r', role,
'-p', party_id,
'-c', task_parameters_path,
'--run_ip', RuntimeConfig.JOB_SERVER_HOST,
'--job_server', '{}:{}'.format(RuntimeConfig.JOB_SERVER_HOST, RuntimeConfig.HTTP_PORT),
]

所以如果一个任务久久没有完成,首先看 nodemanager 中对应 pid 的进程是否存在,如果还存在,说明任务确实还在运行。根据代码的逻辑,一旦出现任何错误,会直接将 task 状态修改为 FAILED 并跳出(除非该进程直接被杀掉)。

Task 的执行过程

通过前面的章节我们可以知道,FateFlow 在接收到 grpc 请求后,task_controller 会直接通过命令行创建新的进程来运行不同的 Task。我们任意找到一个 job 的日志,搜索 python(注意有个空格),就可以看到对应的命令:

1
[INFO] [2021-09-23 15:15:43,681] [23:140533838096128] - job_utils.py[line:334]: start process command: /opt/app-root/bin/python /data/projects/fate/python/fate_flow/operation/task_executor.py -j 2021092315154055290544 -n dataio_0 -t 2021092315154055290544_dataio_0 -v 0 -r guest -p 10001 -c /data/projects/fate/jobs/2021092315154055290544/guest/10001/dataio_0/2021092315154055290544_dataio_0/0/task_parameters.json --run_ip 192.167.0.100 --job_server 192.167.0.100:9380 successfully, pid is 7015

所以接下来我们只要具体关注 task_executor.py -> run_task(),就可以了解 Task 具体的执行过程了,其中 main 函数如下,只有 2 行:

1
2
3
if __name__ == '__main__':
task_info = TaskExecutor.run_task()
TaskExecutor.report_task_update_to_driver(task_info=task_info)

接下来我们就具体来看一下 run_task() 的具体逻辑

执行前准备工作

因为是通过命令行调用,一开始最重要的就是解析命令行参数,直接在日志中搜索 enter task process 就可以找到对应内容:

1
2
[INFO] [2021-09-23 15:15:45,054] [7015:140077791479616] - task_executor.py[line:56]: enter task process
[INFO] [2021-09-23 15:15:45,054] [7015:140077791479616] - task_executor.py[line:57]: Namespace(component_name='dataio_0', config='/data/projects/fate/jobs/2021092315154055290544/guest/10001/dataio_0/2021092315154055290544_dataio_0/0/task_parameters.json', job_id='2021092315154055290544', job_server='192.167.0.100:9380', party_id=10001, role='guest', run_ip='192.167.0.100', task_id='2021092315154055290544_dataio_0', task_version=0)

之后就是把 task 对应的执行信息更新到数据表中(如 party_id, ip, pid 等),初始化 trackertracker_client、运行时配置,配置 session 相关 id computing_session_idfederation_session_id。直接在日志中搜索 Component parameters on party 就可以找到对应内容:

1
2
[INFO] [2021-09-23 15:16:06,027] [7173:140583914858304] - task_executor.py[line:151]: Component parameters on party {'IntersectParam': ..., 'module': 'Intersection', 'output_data_name': ['data']}
[INFO] [2021-09-23 15:16:06,028] [7173:140583914858304] - task_executor.py[line:152]: Task input dsl {'data': {'data': ['dataio_0.data']}}

具体执行过程

执行的入口很简单,就是根据对应的包名,找到对应的类,具体如下:

1
2
3
4
5
6
7
8
9
run_object = getattr(importlib.import_module(run_class_package), run_class_name)()
run_object.set_tracker(tracker=tracker_client)
run_object.set_task_version_id(task_version_id=job_utils.generate_task_version_id(task_id, task_version))
# add profile logs
profile.profile_start()
# 实际执行的函数
run_object.run(component_parameters_on_party, task_run_args)
profile.profile_ends()
output_data = run_object.save_data()

可以看到实际上执行的就是各个 component 的 run 函数(更加具体的执行在后面用实际例子进行说明)。

执行完成后工作

如果执行顺利完成,那么就会导出数据和模型,对应代码为:

1
2
3
4
5
6
7
8
9
# 输出数据
output_data = run_object.save_data()
persistent_table_namespace, persistent_table_name = tracker.save_output_data(
computing_table=output_data[index],
output_storage_engine=job_parameters.storage_engine,
output_storage_address=job_parameters.engines_address.get(EngineType.STORAGE, {}))
# 输出模型
output_model = run_object.export_model()
tracker.save_output_model(output_model, task_output_dsl['model'][0] if task_output_dsl.get('model') else 'default')

最终更新任务状态,并同步给 driver(具体参考上一节)

Component 的执行过程

从上一节我们可以知道,Task 的执行实际上就是调用各个 Component 的 run 函数,那么接下来我们就以三个不同的 component 为例子(本地模式、无 arbiter 两方、有 arbiter 两方),来说明具体的计算是如何执行的。

因为要看具体的 component,我们就需要来到 federatedml 文件夹,根据前面的说明,对应选择的 component 为:

  1. fate_flow/components/upload.py 上传组件,本地模式
  2. federatedml/toy_example/secure_add_[guest|host].py 最简单的双边测试程序,arbiter 两方
  3. federatedml/linear_model/hetero_logistic_regression/hetero_lr_[arbiter|guest|host].py 有 arbiter 的纵向联邦 LR 算法,有 arbiter 两方

上传

该组件使用 local 模式执行,也就是只在本地执行,代码位于:fate_flow/components/upload.py

基类 ComponentBase

首先我们可以看到,任何一个组件都有一个基类 ComponentBase,并且必须要重载 run 函数:

1
2
3
4
5
6
7
# 其他函数没有一一列出
class Upload(ComponentBase):
....
def __init(self):
...
def run(self, component_parameters=None, args=None):
...

因为是第一次讲解 component,我们就花点时间先了解下 ComponentBase 类:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
class ComponentBase(object):
def __init__(self):
self.task_version_id = ''
self.tracker = None
self.model_output = None
self.data_output = None

def run(self, component_parameters: dict = None, run_args: dict = None):
pass

def set_tracker(self, tracker):
self.tracker = tracker

def save_data(self):
return self.data_output

def export_model(self):
return self.model_output

def set_task_version_id(self, task_version_id):
self.task_version_id = task_version_id

可以看到每个 component 都至少有如下功能:

  1. 执行逻辑 run
  2. 保存数据 save_data
  3. 保存模型 export_model
  4. 设置追踪器 set_tracker

具体上传逻辑

接下来我们就可以回到 Upload.run() 函数中,看看具体如何将数据上传到 FATE,具体步骤如下:

  1. 检查各项任务配置
  2. 如果指定了强制删除同名的表,那么会先检测并删除
  3. 创建指定的表,并保存数据,这里支持不同的存储(Eggroll, MySQL, 本地文件, HDFS),一般来说都是使用 Eggroll,所以需要逐行读取并进行处理
    1. 如果不是是本地文件,则需要逐行读取文件,并将数据上传到数据表中 save_data_table
    2. 如果是本地文件,则直接统计数据量 get_data_table_count

注:具体访问存储都会以 session 的形式访问,这里封装了不同存储的访问模式,用通用的接口供程序调用,这里就不展开。

安全求和

该组件会在 guest 和 host 方分别执行,所以会分为两个代码文件代码位于:federatedml/toy_example/secure_add_[guest|host].py。因为这个涉及到两方通信,也引入了组件参数的概念,所以我们需要额外设定一些变量用于组件设置与变量传输。

组件参数

组件对应的参数可以在 federatedml/param/secure_add_example_param.py 中看到,主要是指定了随机数种子、分区个数和数据量,在实际使用过程中,这里设定的参数将可以通过 json 格式的模块配置指定,具体如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# 基类是 BaseParam,指定了参数的常用函数,如校验
class SecureAddExampleParam(BaseParam):
def __init__(self, seed=None, partition=1, data_num=1000):
self.seed = seed
self.partition = partition
self.data_num = data_num

def check(self):
if self.seed is not None and type(self.seed).__name__ != "int":
raise ValueError("random seed should be None or integers")

if type(self.partition).__name__ != "int" or self.partition < 1:
raise ValueError("partition should be an integer large than 0")

if type(self.data_num).__name__ != "int" or self.data_num < 1:
raise ValueError("data_num should be an integer large than 0")

传输变量参数

在计算的过程中我们需要在 guest 和 host 之间传输数据,所以需要提前进行定义,对应文件为 federatedml/transfer_variable/transfer_class/secure_add_example_transfer_vairable.py,只有在这里指定的变量,后续才能通过 rollsite 模块进行数据传输,也需要指定传输方向,具体如下:

1
2
3
4
5
6
7
# 基类是 BaseTransferVariables,指定了参数的常用函数,如创建变量等
class SecureAddExampleTransferVariable(BaseTransferVariables):
def __init__(self, flowid=0):
super().__init__(flowid)
self.guest_share = self._create_variable(name='guest_share', src=['guest'], dst=['host'])
self.host_share = self._create_variable(name='host_share', src=['host'], dst=['guest'])
self.host_sum = self._create_variable(name='host_sum', src=['host'], dst=['guest'])

Guest 具体执行逻辑

有了前面的参数设置,我们就可以来一步一步拆解具体的计算了,这里在 run 函数中已经进行了详细的注释,具体如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def run(self, component_parameters=None, args=None):
LOGGER.info("begin to init parameters of secure add example guest")
self._init_runtime_parameters(component_parameters)

LOGGER.info("begin to make guest data")
self._init_data()

LOGGER.info("split data into two random parts")
self.secure()

LOGGER.info("share one random part data to host")
self.sync_share_to_host()

LOGGER.info("get share of one random part data from host")
self.recv_share_from_host()

LOGGER.info("begin to get sum of guest and host")
guest_sum = self.add()

LOGGER.info("receive host sum from guest")
host_sum = self.recv_host_sum_from_host()

secure_sum = self.reconstruct(guest_sum, host_sum)

assert (np.abs(secure_sum - self.data_num * 2) < 1e-6)

LOGGER.info("success to calculate secure_sum, it is {}".format(secure_sum))

当然,我们不能满足于只知道具体流程,更要了解具体每个步骤在做什么,简单的讲解如下(更详细的可自行查看源代码,并不难):

  1. 从 json 文件内容中提取并初始化执行时的参数 _init_runtime_parameters
  2. 生成数据 _init_data。假设是 1000 个数据的话,那么就 (0,1) 一直到 (999,1) 共 1000 个键值对,并且通过 session.parallelize 将键值对数组转为计算用的 table,保存在变量 x
  3. 将数据随机拆分为两份 secure(),比如前面的 (0,1) 会被拆分为 (0,0.2) 和 (0,0.8),1000 个都拆分完成后,保存在变量 x1, x2
  4. 将其中一份数据传输给 host sync_share_to_host(),实际上调用的是 guest_share.remote(self.x2, ...),表示把 x2 发送给 host
  5. 从 host 获取一份数据 recv_share_from_host(),实际上调用的是 host_share.get(idx=0),并保存在 y1
  6. 计算 guest 这边所有数据的和 add(),通过 map-reduce 的方式将 x1 与 y1 的所有值相加求和
  7. 获取 host 方所有数据的和 recv_host_sum_from_host(),实际上调用的是 host_sum.get(idx=0)
  8. 计算总的和 reconstruct(),如果与 2000 相差不到 1e-6,则算法执行成功

Host 具体执行逻辑

有了前面的参数设置,我们就可以来一步一步拆解具体的计算了,这里在 run 函数中已经进行了详细的注释,具体如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def run(self, component_parameters=None, args=None):
LOGGER.info("begin to init parameters of secure add example host")
self._init_runtime_parameters(component_parameters)

LOGGER.info("begin to make host data")
self._init_data()

LOGGER.info("split data into two random parts")
self.secure()

LOGGER.info("get share of one random part data from guest")
self.recv_share_from_guest()

LOGGER.info("share one random part data to guest")
self.sync_share_to_guest()

LOGGER.info("begin to get sum of host and guest")
host_sum = self.add()

LOGGER.info("send host sum to guest")
self.sync_host_sum_to_guest(host_sum)

因为在 guest 部分已经详细介绍了具体的逻辑,这里简单过一下 host 方的计算,不详细展开了:

  1. 初始化参数、初始化数据和拆分数据都和 guest 方一样
  2. 从 guest 方拉取数据 recv_share_from_guest()
  3. 推送数据到 guest self.sync_share_to_guest()
  4. 计算 host 这边的和 add()
  5. 同步结果给 guest sync_host_sum_to_guest()

逻辑回归

该组件会在 arbiter,guest 和 host 方分别执行,所以会分为三个代码文件代码位于:federatedml/linear_model/hetero_logistic_regression/hetero_lr_[arbiter|guest|host].py

传输变量参数

组件的参数这里就不展开来说明了,随着算法的复杂度上升,具体的配置也会增加,具体参考 federatedml/param/logistic_regression_param.py 文件,而因为有 arbiter 的加入,在配置传输变量的时候会和之前有些不同,参考 federatedml/transfer_variable/transfer_class/hetero_lr_transfer_variable.py 文件,具体如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class HeteroLRTransferVariable(BaseTransferVariables):
def __init__(self, flowid=0):
super().__init__(flowid)
self.batch_data_index = self._create_variable(name='batch_data_index', src=['guest'], dst=['host'])
self.batch_info = self._create_variable(name='batch_info', src=['guest'], dst=['host', 'arbiter'])
self.converge_flag = self._create_variable(name='converge_flag', src=['arbiter'], dst=['host', 'guest'])
self.fore_gradient = self._create_variable(name='fore_gradient', src=['guest'], dst=['host'])
self.forward_hess = self._create_variable(name='forward_hess', src=['guest'], dst=['host'])
self.guest_gradient = self._create_variable(name='guest_gradient', src=['guest'], dst=['arbiter'])
self.guest_hess_vector = self._create_variable(name='guest_hess_vector', src=['guest'], dst=['arbiter'])
self.guest_optim_gradient = self._create_variable(name='guest_optim_gradient', src=['arbiter'], dst=['guest'])
self.host_forward_dict = self._create_variable(name='host_forward_dict', src=['host'], dst=['guest'])
self.host_gradient = self._create_variable(name='host_gradient', src=['host'], dst=['arbiter'])
self.host_hess_vector = self._create_variable(name='host_hess_vector', src=['host'], dst=['arbiter'])
self.host_loss_regular = self._create_variable(name='host_loss_regular', src=['host'], dst=['guest'])
self.host_optim_gradient = self._create_variable(name='host_optim_gradient', src=['arbiter'], dst=['host'])
self.host_prob = self._create_variable(name='host_prob', src=['host'], dst=['guest'])
self.host_sqn_forwards = self._create_variable(name='host_sqn_forwards', src=['host'], dst=['guest'])
self.loss = self._create_variable(name='loss', src=['guest'], dst=['arbiter'])
self.loss_intermediate = self._create_variable(name='loss_intermediate', src=['host'], dst=['guest'])
self.paillier_pubkey = self._create_variable(name='paillier_pubkey', src=['arbiter'], dst=['host', 'guest'])
self.sqn_sample_index = self._create_variable(name='sqn_sample_index', src=['guest'], dst=['host'])

Arbiter 特别需要注意的地方

在包含有 arbiter 的算法中,一定需要注意,在 model 的定义中,不能有 transfer_variable 的类型,如果需要使用到传输的数据,需要另外定义一个类,否则会导致 arbiter 方的任务无法正常结束。

如果要从 0 开始开发一个新的算法,那么需要做的步骤还会更多,具体参考官方文档(英文|中文),因为本文不涉及算法开发部分,不再展开,后续会专门写文章进行讲解。

数据是如何传输的

通过前面的章节我们可以看到,除了本地 local 模式之外,其他需要多方的计算,都离不开 transfer_variable 这个类,那么具体数据是如何在 guest/arbiter/host 方传输的呢?本节我们将继续通过源码进行讲解。

我们继续用前面提到的安全求和来进行说明:

1
2
3
4
5
6
7
# 基类是 BaseTransferVariables,指定了参数的常用函数,如创建变量等
class SecureAddExampleTransferVariable(BaseTransferVariables):
def __init__(self, flowid=0):
super().__init__(flowid)
self.guest_share = self._create_variable(name='guest_share', src=['guest'], dst=['host'])
self.host_share = self._create_variable(name='host_share', src=['host'], dst=['guest'])
self.host_sum = self._create_variable(name='host_sum', src=['host'], dst=['guest'])

可以看到这里是通过 _create_varible() 函数来创建用来传输的变量的,经过层层寻找我们可以定位到 fate_arch/federation/transfer_variable/_transfer_variable.py 这个文件,并且实际上变量传输是通过如下两行关键代码实现的:

  • 推送 session.federation.remote(v=obj, name=name, tag=tag, parties=parties, gc=self._remote_gc)
  • 接收 session.federation.get(name=name, tag=tag, parties=parties, gc=self._get_gc)

这里的 federation 是一个基于抽象类 FederationABC 的具体实现,这里我们主要来看看基于 Eggroll 的实现(其他方式类似),对应文件为 fate_arch/federation/eggroll/_federation.py,代码为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class Federation(FederationABC):
def __init__(self, rp_ctx, rs_session_id, party, proxy_endpoint):
LOGGER.debug(f"[federation.eggroll]init federation: "
f"rp_session_id={rp_ctx.session_id}, rs_session_id={rs_session_id}, "
f"party={party}, proxy_endpoint={proxy_endpoint}")

options = {
'self_role': party.role,
'self_party_id': party.party_id,
'proxy_endpoint': proxy_endpoint
}
self._rsc = RollSiteContext(rs_session_id, rp_ctx=rp_ctx, options=options)
LOGGER.debug(f"[federation.eggroll]init federation context done")

def get(self, name, tag, parties, gc):
parties = [(party.role, party.party_id) for party in parties]
raw_result = _get(name, tag, parties, self._rsc, gc)
return [Table(v) if isinstance(v, RollPair) else v for v in raw_result]

def remote(self, v, name, tag, parties, gc):
if isinstance(v, Table):
# noinspection PyProtectedMember
v = v._rp
parties = [(party.role, party.party_id) for party in parties]
_remote(v, name, tag, parties, self._rsc, gc)

从上面的源代码就可以看到之前我们用来推送和接收数据的函数 getremote。在日志中搜索 federation.eggroll]init 就可以看到对应的日志:

1
2
[DEBUG] [2021-09-23 15:19:58,682] [10767:140079964809024] - _federation.py[line:35]: [federation.eggroll]init federation: rp_session_id=2021092315154055290544_hetero_kmeans_0_0_arbiter_10001, rs_session_id=2021092315154055290544_hetero_kmeans_0_0, party=Party(role=arbiter, party_id=10001), proxy_endpoint=rollsite:9370
[DEBUG] [2021-09-23 15:19:58,683] [10767:140079964809024] - _federation.py[line:45]: [federation.eggroll]init federation context done

推送 remote

在推送的时候,最重要的函数是 _push_with_exception_handle,在这里会进行 rpc 调用并根据调用情况执行具体的回调任务:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def _push_with_exception_handle(rsc, v, name, tag, parties):
def _remote_exception_re_raise(f, p):
try:
f.result()
LOGGER.debug(f"[federation.eggroll.remote.{name}.{tag}]future to remote to party: {p} done")
except Exception as e:
pid = os.getpid()
LOGGER.exception(f"[federation.eggroll.remote.{name}.{tag}]future to remote to party: {p} fail,"
f" terminating process(pid={pid})")
import traceback
print(f"federation.eggroll.remote.{name}.{tag} future to remote to party: {p} fail,"
f" terminating process {pid}, traceback: {traceback.format_exc()}")
os.kill(pid, signal.SIGTERM)
raise e

def _get_call_back_func(p):
def _callback(f):
return _remote_exception_re_raise(f, p)

return _callback

rs = rsc.load(name=name, tag=tag)
futures = rs.push(obj=v, parties=parties)
for party, future in zip(parties, futures):
future.add_done_callback(_get_call_back_func(party))
return rs

同样的,我们可以在日志中搜索 [federation.eggroll.remote 查看对应的内容:

1
2
[DEBUG] [2021-09-23 15:20:01,808] [10767:140079964809024] - _federation.py[line:86]: [federation.eggroll.remote.hash.6a4b3c044fae5be3a8a8.uuid.default[('guest', '10001')])]remote object with type: <class 'str'>
[DEBUG] [2021-09-23 15:20:01,812] [10767:140079964809024] - profile.py[line:185]: [federation.remote.federatedml.framework.hetero.procedure.table_aggregator.TableScatterTransVar.RandomPaddingCipherTransVar.UUIDTransVar.uuid.default]arbiter->[Party(role=guest, party_id=10001)] done

获取 get

在获取的时候,会根据对应 parties 的 grpc 连接进行数据拉取,代码如下:

1
2
3
4
5
6
7
8
9
10
rs = rsc.load(name=name, tag=tag)
future_map = dict(zip(rs.pull(parties=parties), parties))
rtn = {}
for future in concurrent.futures.as_completed(future_map):
party = future_map[future]
# 获取结果
v = future.result()
# 获取之后的后处理
rtn[party] = _get_value_post_process(v, name, tag, party, rsc, gc)
return [rtn[party] for party in parties]

对应的日志为:

1
2
DEBUG] [2021-09-23 15:20:10,538] [10767:140079964809024] - _federation.py[line:196]: [federation.eggroll.get.hash.ca11702d7d7e134e6148.p_power_r.default] got object with type: <class 'tuple'>
[DEBUG] [2021-09-23 15:20:10,538] [10767:140079964809024] - profile.py[line:216]: [federation.get.federatedml.framework.hetero.procedure.table_aggregator.TableScatterTransVar.RandomPaddingCipherTransVar.DHTransVar.p_power_r.default]arbiter<-[Party(role=guest, party_id=10001), Party(role=host, party_id=10009), Party(role=host, party_id=10010)] done

简单来说:在 Eggroll 作为 backend 的时候,具体的数据是通过 grpc 的方式进行推送和传输的。

于是又来了一个问题,在 eggroll 的层面上,具体是如何进行数据传输的呢?因为本文主要还是围绕 FateFlow 来介绍,关于 eggroll 的就六道下一篇文章吧。