这是一篇来自 Google 团队在 KDD 2017 的论文。这里主要记录一下他们的整体思路。
更新历史
- 2018.03.31: 完成初稿
摘要
构建和维护一个用于机器学习模型训练和部署的平台需要精心架构许多组件。主要包括
- Learner - 根据训练数据产生模型
- 分析和验证模型与数据
- 用于部署模型的基础设施
目前大部分这样的系统都是由各种胶水代码组成的,由不同的团队维护,整个系统非常脆弱,因此 Google 做了 TFX。论文会以 Google Play 团队为例子来介绍
1 简介
概念上的机器学习工作流很简单,但是实际情况中有非常多的情况需要处理,如何自动化处理这些异常情况并保证训练正常进行就是机器学习平台需要做的事情。除此之外还需要能够做到以下几点:
- 能够支持大部分的通用任务,以及如果出现比较特别的需求,能够进行扩展支持
- 持续训练和提供服务,比如同样的模型用不同的算法训练,不同的模型用同样的算法训练
- 提供用户界面,用于基础配置、分析数据与模型
- 生产级别的可靠性与可扩展性,能够处理大量数据
2 平台总览
2.1 背景和相关工作
相关工作的经验已证明机器学习算法只是平台的一小部分,数据和模型并行使得分布式系统架构成为唯一的选择。除了简单的组件间连接,整个工作流需要能够快速搭建(甚至是自动构建)。训练了多个模型后,需要把这些模型的信息保存在同一个数据库中。还需要能够让非专家快速上手使用。
2.2 平台设计分析
要点一:一个平台,多种任务。Tensorflow 是核心的算法框架,但实际上平台不应局限于此。除了算法库外,数据分析、验证与可视化工具都需要支持稀疏、稠密或序列数据。模型验证、评估和部署工具也需要支持不同的算法类型(回归、分类、序列)
要点二:持续训练。大部分机器学习流水线可以用工作流(workflow)或是依赖图(dependency grapch)表示,实际上就是以特定的顺序执行一系列操作。TFX 支持几种从 data visitation 到 warm-starting 的持续训练策略。Data visitation 可以是静态或动态的;Warm-starting 会从上一个状态初始化部分模型的参数(用以支持增量训练)
要点三:简单的配置方式与实用工具。
要点四:生产级别的可靠性与可扩展性。模型验证是很重要的,模型验证又包括数据验证(保证不让垃圾模型更新到生产)。上线前还需要验证基础设施的环境。面对大数据场景,Google 团队推荐 Apache Beam,可以很好地处理训练、模型评估和批量预测的需求。
难点在于如何把这么多组件以一个比较统一的方式糅合在一个系统中的同时,还要符合之前提到的规则。
3 数据分析、转换和验证
好的数据是成功的基础。数据通常来源于各个系统,这些数据会因为各种系统可能出现的问题而不规范,所以要提高数据质量,就必须要在数据这一层进行发现、诊断和修复。难点在于需要支持各种数据分析和验证的场景,需要容易部署,需要能够快速做基础的验证。注:如果数据验证出错会告警,那么大部分用户会直接关掉数据验证功能或者根本不看告警。
3.1 数据分析
输入是数据集,输出是一系列描述性的统计数据。对于连续型数据,一般包括分位数、直方图、平均数、标准差;对于离散性,一般是 top-K 的值和频率。还需要支持指定特征切分统计(比如正负样本数量)和特征统计(比如协方差与相关性)。通过观察这些统计值,用户可以大概了解不同的数据集。在持续训练的过程中,数据统计需要高效完成(对于大批量数据可以考虑使用近似算法处理)
3.2 数据转换
在训练或预测前,一般需要对数据的格式进行转换,比如 feature-to-integer mappling(vocabularies),这一步可能还会包括去掉低频词。这一步的重点在于一定要保证训练和预测时对数据做相同的转换,不然会导致数据不一致的问题。
3.3 数据验证
使用 schema 来描述数据规范,需要包括:
- 数据中的特征
- 每个特征的值类型
- 每个特征所需要的样本数
- 每个样本的值范围
- 每个特征的具体领域
如下图所示
每个组来维护自己的数据集的 schema,平台也会提供工具来帮助生成这样的 schema。据此,数据验证模块的核心的设计原则为:
- 用户可以一眼看到数据集中检测出来的异常情况以及覆盖率
- 每个异常情况需要提供简单的描述来告诉用户如何去修正数据。比如告诉用户某条数据某个值不在某个范围内,并给出这条数据的内容。但对于某些不是很直观的数据(如 KL 散度),可能修复起来就困难得多
- 某些情况下出现异常值是因为数据集本身整体趋势的变化,应该提供选项来修改 schema(而非修改数据),平台也会展示如何修改 schema 来消除这些数据异常
- 希望用户能把数据异常放到跟代码 bug 一样的重视程度来处理,这些异常会像 bug 被文档记录下来并跟踪解决的情况
4 模型训练
TFX 的其中过一个核心设计哲学就是把训练产品级优秀模型给流水线化(并尽可能自动化),支持 Tensorflow 的所有功能,只需集成一次,就可以使用各种算法。为了解决大数据集训练占用过多资源的问题,Google 采用 warm-starting 的方法来减少资源消耗。
4.1 Warm-Starting
在某些场景下,模型的时效性很重要,在数据集大的时候,完整训练耗费时间过长,可能根本无法及时更新模型,warm-starting 是在模型质量和模型时效性之间找到的一个平衡。
Warm-Starting 受迁移学习启发,即先由基础数据集训练出基础模型,然后用基础模型的参数去初始化目标模型,最后由目标数据集来训练目标模型。同样的,这种方法也可以用在持续训练的场景,即用之前模型的参数,加上新的数据,进行增量训练。
4.2 高级模型规范 API
实际上是对 Tensorflow API 的再封装(类似于 Keras),这能够带来集大的效率提升(减少代码重复、提高代码质量)。比如说名为 FeatureColumns 的抽象,用于帮助用户去处理模型要用的特征;再比如 Estimator,对于给定的模型,Estimator 会处理训练和评估两个步骤(可以单机也可以分布式),一个简单的例子如下(注:非常接近与 keras)
5 模型评估和验证
自动模型评估主要用于验证模型的有效性,避免因为人工查看不及时而无法感知模型的退化,也不能及时换上新的模型,就影响用户体验。
5.1 定义“好”模型
要进行评估,就需要先知道如何评估,也就是说要有个标准,TFX 的标准有二:
- safe to serve 即模型完整,载入和预测的过程中不会出现错误,满足健壮性要求,不能占用太多的资源
- desired predication quality 模型预测准确率
5.2 评估:用户查看模型质量指标
在做模型 A/B 测试的时候,不可能一开始就直接上生产,所以肯定是拿一部分生产数据离线进行评估,平台会给出一些指标,让用户判断模型质量,比如 AUC,train loss, top-N 准确率等等,离线评测过了之后,再去发布到线上。
5.3 验证:机器判断模型的质量
当一个模型已经上线之后,就会进入机器自动验证模型质量的阶段。主要是评测预测准确率,一旦准确率低于某个基线,就会发送告警。这里的主要难点是如何去设置阈值,因为如果太过于严格,告警消息满天飞,最终没人看;如果太宽松,则无法及时发现问题。根据他们跟产品团队的合作经验,使用比较宽松的阈值可以在没有太多告警的同时发现大部分严重问题。
5.4 切分
评估的数据集支持根据 feature 来区分,比如一个产品团队可能只关心这个模型在美国地区的表现,他们可以设置 Country=US 来进行筛选和评估。
5.5 用户对模型验证的态度
起初 TFX 任务可能大部分团队会希望能有模型验证的功能,但实际调查表明,因为模型验证需要增加太多额外的配置,如果配置不对还会带来各种告警甚至模型无法及时更新,所以大部分团队没有启用。他们的建议是零配置的模型验证,这样大家才会真正去用这个功能。
6 模型部署
模型不上线,之前的努力全都白搭,TFX 采用 tensorflow serving 来进行模型部署。
注:这部分我们目前暂时不会特别涉及(因为私有化部署环境下各有各的部署流程),感兴趣的同学可以去查看原论文。