Deploying a Bert-VITS2 Model on Modal

Bound by the original repository's open-source license, all code in this post is released under the AGPL-3.0 license.

I first heard about Modal quite a while ago, when news of its funding round came out. I trained a Bert-VITS2 model and deployed it there back then, but as the model iterated through versions quickly I stopped keeping up with it. Now that V2.3 is a reasonably stable release, I spent three days training a model and another half day deploying it to Modal. This post walks through the process, and the code is open-sourced.

Modal is a FaaS platform that currently targets mainly Python. You deploy functions or a web API (Flask, FastAPI, etc.) to the cloud and then invoke them through an HTTP API or the SDK. FaaS itself is nothing new, but Modal's selling point is GPU support with per-second billing. That makes it a great fit for deploying models that are invoked on demand: the model does not need to hold GPU memory around the clock, and each call is cheap, roughly $0.02 per text-to-speech request in my case. For larger models, the platform also offers high-memory GPUs such as the A100.
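
To make the programming model concrete before the full deployment script, here is a minimal, hedged sketch of a Modal app in the same Stub-style API used below (the app and function names are made up, and newer Modal releases have since renamed parts of this API):

import modal

stub = modal.Stub("hello-modal")

# Requesting a GPU only for this function; billing accrues per second,
# only while a call is actually running.
@stub.function(gpu="any")
def gpu_hello(name: str) -> str:
    import subprocess
    # nvidia-smi is available inside GPU containers, so this lists the attached GPU
    subprocess.run(["nvidia-smi", "-L"], check=False)
    return f"hello {name} from a GPU container"

@stub.local_entrypoint()
def main():
    # Runs the function remotely on Modal and prints its return value locally
    print(gpu_hello.remote("modal"))

Running this with modal run executes main locally while gpu_hello runs in a GPU container in the cloud.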

On top of that, each user seems to get about $30 of free credit per month.

Deployment

The code is open-sourced on the modal branch of KernelErr/Bert-VITS2. The main code is as follows:

import os
import modal
import fastapi
import fastapi.staticfiles
from modal_const import CACHE_PATH
from pydantic import BaseModel
from fastapi.responses import Response

stub = modal.Stub("bert-vits2")
web_app = fastapi.FastAPI()


def download_model_weights():
    import requests
    from huggingface_hub import snapshot_download
    from transformers.utils import move_cache

    model_files = ["config.json", "D_88000.pth", "DUR_88000.pth", "G_88000.pth", "WD_88000.pth"]
    os.makedirs(CACHE_PATH, exist_ok=True)
    for model_file in model_files:
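        # "HOSTURL" is a placeholder; replace it with wherever you host the weight files (see the note below the listing)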
        rsp = requests.get("HOSTURL" + model_file)
        rsp.raise_for_status()
        with open(os.path.join(CACHE_PATH, model_file), "wb") as f:
            for chunk in rsp.iter_content(chunk_size=1024*1024): 
                if chunk:
                    f.write(chunk)

    os.makedirs(CACHE_PATH + "/bert/deberta-v2-large-japanese-char-wwm", exist_ok=True)
    snapshot_download(
        "ku-nlp/deberta-v2-large-japanese-char-wwm",
        local_dir=CACHE_PATH + "/bert/deberta-v2-large-japanese-char-wwm",
    )
    move_cache()

    os.makedirs(CACHE_PATH + "/bert/chinese-roberta-wwm-ext-large", exist_ok=True)
    snapshot_download(
        "hfl/chinese-roberta-wwm-ext-large",
        local_dir=CACHE_PATH + "/bert/chinese-roberta-wwm-ext-large",
    )
    move_cache()

    os.makedirs(CACHE_PATH + "/bert/deberta-v3-large", exist_ok=True)
    snapshot_download(
        "microsoft/deberta-v3-large",
        local_dir=CACHE_PATH + "/bert/deberta-v3-large",
    )
    move_cache()

    os.makedirs(CACHE_PATH + "/bert/deberta-v2-large-japanese", exist_ok=True)
    snapshot_download(
        "ku-nlp/deberta-v2-large-japanese",
        local_dir=CACHE_PATH + "/bert/deberta-v2-large-japanese",
    )
    move_cache()

    os.makedirs(CACHE_PATH + "/bert/bert-base-japanese-v3", exist_ok=True)
    snapshot_download(
        "cl-tohoku/bert-base-japanese-v3",
        local_dir=CACHE_PATH + "/bert/bert-base-japanese-v3",
    )
    move_cache()

    import nltk
    nltk.download('averaged_perceptron_tagger')
    nltk.download('cmudict')


image = (
    modal.Image.debian_slim(python_version="3.10")
        .pip_install(
            "librosa==0.9.2",
            "matplotlib",
            "numpy",
            "numba",
            "phonemizer",
            "scipy",
            "tensorboard",
            "Unidecode",
            "amfm_decompy",
            "jieba",
            "transformers",
            "pypinyin",
            "cn2an",
            "gradio==3.50.2",
            "av",
            "mecab-python3",
            "loguru",
            "unidic-lite",
            "cmudict",
            "fugashi",
            "num2words",
            "PyYAML",
            "requests",
            "pyopenjtalk-prebuilt",
            "jaconv",
            "psutil",
            "GPUtil",
            "vector_quantize_pytorch",
            "g2p_en",
            "sentencepiece",
            "pykakasi",
            "langid",
            "torch",
            "torchvision",
            "torchaudio",
        )
        .run_function(download_model_weights)
)

@stub.function(
    gpu="l4",
    image=image,
    retries=3,
    mounts=[
        modal.Mount.from_local_python_packages("config"),
        modal.Mount.from_local_python_packages("tools"),
        modal.Mount.from_local_python_packages("utils"),
        modal.Mount.from_local_python_packages("infer"),
        modal.Mount.from_local_python_packages("re_matching"),
        modal.Mount.from_local_python_packages("modal_const"),
        modal.Mount.from_local_python_packages("commons"),
        modal.Mount.from_local_python_packages("text"),
        modal.Mount.from_local_python_packages("models"),
        modal.Mount.from_local_python_packages("modules"),
        modal.Mount.from_local_python_packages("transforms"),
        modal.Mount.from_local_python_packages("attentions"),
        modal.Mount.from_local_python_packages("monotonic_align"),
        modal.Mount.from_local_python_packages("oldVersion"),
        modal.Mount.from_local_file("config.yml", CACHE_PATH + "/config.yml"),
        modal.Mount.from_local_file("bert/bert_models.json", CACHE_PATH + "/bert/bert_models.json"),
    ]
)
def speech(
    text,
    speaker,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    language,
    style_text,
    style_weight,
):
    import torch
    import utils
    from infer import infer, latest_version, get_net_g
    import gradio as gr
    import numpy as np
    from config import config
    import wave
    import tempfile

    net_g = None
    device = config.webui_config.device
    hps = utils.get_hparams_from_file(config.webui_config.config_path)
    version = hps.version if hasattr(hps, "version") else latest_version
    net_g = get_net_g(
        model_path=config.webui_config.model, version=version, device=device, hps=hps
    )
    speaker_ids = hps.data.spk2id
    # speakers = list(speaker_ids.keys())
    # languages = ["ZH", "JP", "EN", "mix", "auto"]

    def generate_audio(
        text,
        sdp_ratio,
        noise_scale,
        noise_scale_w,
        length_scale,
        speaker,
        language,
        reference_audio,
        emotion,
        style_text,
        style_weight,
        skip_start=False,
        skip_end=False,
    ):
        slices = text.split("|")
        audio_list = []
        # silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
        with torch.no_grad():
            for idx, piece in enumerate(slices):
                skip_start = idx != 0
                skip_end = idx != len(slices) - 1
                audio = infer(
                    piece,
                    reference_audio=reference_audio,
                    emotion=emotion,
                    sdp_ratio=sdp_ratio,
                    noise_scale=noise_scale,
                    noise_scale_w=noise_scale_w,
                    length_scale=length_scale,
                    sid=speaker,
                    language=language,
                    hps=hps,
                    net_g=net_g,
                    device=device,
                    skip_start=skip_start,
                    skip_end=skip_end,
                    style_text=style_text,
                    style_weight=style_weight,
                )
                audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
                audio_list.append(audio16bit)
        return np.concatenate(audio_list)

    res = generate_audio(text, sdp_ratio, noise_scale, noise_scale_w, length_scale, speaker, language, None, "Happy", style_text, style_weight)
    data = res.tobytes()
    tempfd, temppath = tempfile.mkstemp()
    with wave.open(temppath, "wb") as wav_file:
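        # mono, 16-bit samples, 44.1 kHz; the zero frame count is patched up when the file is closed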
        wav_file.setparams((1, 2, 44100, 0, 'NONE', 'NONE'))
        wav_file.writeframes(data)

    ret = b""
    with open(temppath, "rb") as wav_file:
        ret = wav_file.read()

    os.close(tempfd)

    return ret

class SpeechReq(BaseModel):
    text: str
    speaker: str
    sdp_ratio: float = 0.5
    noise_scale: float = 0.6
    noise_scale_w: float = 0.9
    length_scale: float = 1.0
    language: str = "ZH"
    style_text: str = ""
    style_weight: float = 0.7

@web_app.post("/submit")
async def submit(req: SpeechReq):
    speech = modal.Function.lookup("bert-vits2", "speech")
    call = speech.spawn(
        req.text,
        req.speaker,
        req.sdp_ratio,
        req.noise_scale,
        req.noise_scale_w,
        req.length_scale,
        req.language,
        req.style_text,
        req.style_weight,
    )
    return {"call_id": call.object_id}

@web_app.get("/result/{call_id}")
async def poll_results(call_id: str):
    from modal.functions import FunctionCall

    function_call = FunctionCall.from_id(call_id)
    try:
        result = function_call.get(timeout=0)
    except TimeoutError:
        return fastapi.responses.JSONResponse(content="Still running", status_code=202)

    return Response(content=result, media_type="audio/x-wav")

@stub.function()
@modal.asgi_app()
def wrapper():
    return web_app

If you don't care about the details, the only thing you need to change is the part of download_model_weights that fetches the model files. I host the weights on my own server; change model_files to the names of your model config and weight files, and replace HOSTURL with your download URL. If your model lives on S3 or somewhere else, adapt it as needed.
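
For instance, if the weights lived on S3 instead of a plain HTTP server, the download loop could be replaced with something along these lines. This is only a hedged sketch: the bucket and prefix are made up, boto3 would have to be added to the image's pip_install list, and credentials would typically be supplied through a Modal secret.

import os
import boto3
from modal_const import CACHE_PATH

def download_model_weights_from_s3():
    s3 = boto3.client("s3")
    bucket = "my-bert-vits2-weights"   # hypothetical bucket name
    prefix = "v2.3/"                   # hypothetical key prefix
    model_files = ["config.json", "D_88000.pth", "DUR_88000.pth",
                   "G_88000.pth", "WD_88000.pth"]
    os.makedirs(CACHE_PATH, exist_ok=True)
    for model_file in model_files:
        # Downloads each object straight into the container's cache directory
        s3.download_file(bucket, prefix + model_file,
                         os.path.join(CACHE_PATH, model_file))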

Then put the config.yml from your training run in the project root and adjust the paths (only part of the config is shown here):

dataset_path: "/root/model_cache"

# webui configuration
# note: a space is required after ":"
webui:
  # inference device
  device: "cuda"
  # model path
  model: "G_88000.pth"
  # config file path
  config_path: "config.json"
  # port
  port: 7860
  # whether to share publicly (expose to the internet)
  share: false
  # whether to enable debug mode
  debug: false
  # language identification library, options: langid, fastlid
  language_identification_library: "langid"

All model files are downloaded into the /root/model_cache directory, so the model path in the webui config only needs the file name.

Once that's done, install and configure modal locally (though I'd recommend doing this inside a pipenv-managed virtual environment):

python3 -m pip install modal
python3 -m modal setup

Then deploy:

modal deploy modal_deploy.py

A few details

image defines the container image: it starts from the Debian slim base image with Python 3.10, pip-installs the required packages, and then runs our model-download function. Once the image has been built, it is not rebuilt unless its definition changes; only the code is synced on subsequent deploys, which speeds things up.

A stub can be thought of as an app or a project. Decorators on a function let us specify the details: for the speech function we request an L4 GPU, use the image built above, and allow up to 3 retries. Since the project is not a single Python file, the local packages and config files are uploaded through mounts.
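
Once deployed, the speech function can also be called directly through the Modal SDK from any local Python process, using the same lookup the /submit endpoint performs. A hedged sketch (the speaker name is a placeholder from your own training data, and the method names follow the Modal version used here):

import modal

# Look up the deployed function by app name and function name.
speech = modal.Function.lookup("bert-vits2", "speech")

wav_bytes = speech.remote(
    "你好,世界",   # text; "|" splits it into separately synthesized slices
    "my_speaker",  # hypothetical speaker id from your dataset
    0.5,           # sdp_ratio
    0.6,           # noise_scale
    0.9,           # noise_scale_w
    1.0,           # length_scale
    "ZH",          # language
    "",            # style_text
    0.7,           # style_weight
)

with open("out.wav", "wb") as f:
    f.write(wav_bytes)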

Results

Once deployed, an https://***.modal.run URL shows up in the terminal and on the Modal dashboard. Calling the API with Postman:

[Screenshot: submitting a request to the queue]

The request is now queued for processing; with the call id returned by submit, we can hit the result endpoint to fetch the output:

[Screenshot: fetching the result]

If the call has not finished yet, the endpoint returns 202.
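
For completeness, here is a hedged client sketch for the two endpoints, submitting a request and polling until the audio is ready (BASE_URL and the speaker name are placeholders for your own deployment):

import time
import requests

BASE_URL = "https://your-app.modal.run"  # replace with your deployment URL

payload = {
    "text": "你好,世界",
    "speaker": "my_speaker",  # hypothetical speaker id from your dataset
    "language": "ZH",
}
call_id = requests.post(f"{BASE_URL}/submit", json=payload).json()["call_id"]

# Poll /result until the call finishes; 202 means it is still running.
while True:
    rsp = requests.get(f"{BASE_URL}/result/{call_id}")
    if rsp.status_code == 202:
        time.sleep(2)
        continue
    rsp.raise_for_status()
    with open("output.wav", "wb") as f:
        f.write(rsp.content)
    break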