# Vertex AI custom-training job spec. It is handed to the Trainer below as the
# value of TRAINING_ARGS_KEY in its custom_config.
# GOOGLE_CLOUD_PROJECT, GOOGLE_CLOUD_MACHINE_TYPE and tfx_image must already be
# defined by the surrounding pipeline code.
VERTEX_TRAINING_SPEC = {
    "project": GOOGLE_CLOUD_PROJECT,
    "worker_pool_specs": [
        # One worker pool holding a single replica: one machine of the
        # configured type running the given container image.
        {
            "machine_spec": {
                "machine_type": GOOGLE_CLOUD_MACHINE_TYPE,
            },
            "replica_count": 1,
            "container_spec": {
                # Container image the training replica runs.
                "image_uri": tfx_image,
            },
        },
    ],
}
# Trainer component whose training step is dispatched to Google Cloud, with the
# job configured entirely through custom_config.
trainer = tfx.extensions.google_cloud_ai_platform.Trainer(
    # The other Trainer arguments (e.g. module_file=..., examples=...) are
    # elided from this snippet and go here. A bare `...` placeholder in this
    # position is a SyntaxError, because it is not separated from the keyword
    # argument below by a comma, so the elision is written as a comment.
    custom_config={
        # Run training as a Vertex AI job (UCAIP was Vertex AI's pre-launch
        # name).
        # NOTE(review): newer TFX releases also export ENABLE_VERTEX_KEY and
        # VERTEX_REGION_KEY, and the UCAIP names appear to be kept as
        # deprecated aliases. Confirm against the installed TFX version before
        # switching.
        tfx.extensions.google_cloud_ai_platform.ENABLE_UCAIP_KEY: True,
        # Google Cloud region in which the training job is launched.
        tfx.extensions.google_cloud_ai_platform.UCAIP_REGION_KEY: (
            YOUR_GOOGLE_CLOUD_REGION
        ),
        # Job spec (project, worker pools, container image) for the training
        # job.
        tfx.extensions.google_cloud_ai_platform.TRAINING_ARGS_KEY: (
            VERTEX_TRAINING_SPEC
        ),
    },
)