Create new PjRt GPU client with remote devices after coordination service agent is available
This does not modify the previous PjRt GPU client. The test environment case where multiple threads are used to simulate multiple workers is supported. For the PjRt GPU MultiWorkerMirroredStrategy (MWMS) case where there are multiple workers each with a GPU or GPUs, the symptom of a client without this fix is a error message like the following where the number at the end is the ID of the first remote device, e.g. 8 when there are eight local GPUs with IDs 0..7. (Another example is 1 when there is one local GPU with ID 0.) `INVALID_ARGUMENT: No matching device found for device_id 8` Note that while the primary purpose of `BaseGPUDeviceFactory::CreateDevices` is to do one-time initialization, it is often called multiple times. In the typical production MWMS case, it is called both when creating a TF Context and when creating GRPC servers when enabling collectives. In the unit test added by this CL (which uses two GPUs), it is first called when a TF Context is created when starting up the test environment (with both GPUS) and then two worker TF processes are created which both call it twice during MWMS startup (each with the one GPU assigned to the process). Other test cases have different patterns. PiperOrigin-RevId: 611621382
Showing
- tensorflow/core/common_runtime/eager/BUILD 42 additions, 1 deletiontensorflow/core/common_runtime/eager/BUILD
- tensorflow/core/common_runtime/eager/context_distributed_manager.cc 340 additions, 23 deletions.../core/common_runtime/eager/context_distributed_manager.cc
- tensorflow/core/common_runtime/gpu/gpu_device.cc 111 additions, 26 deletionstensorflow/core/common_runtime/gpu/gpu_device.cc
- tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc 7 additions, 0 deletionstensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
- tensorflow/core/tfrt/common/BUILD 6 additions, 0 deletionstensorflow/core/tfrt/common/BUILD
- tensorflow/core/tfrt/common/pjrt_state.cc 13 additions, 0 deletionstensorflow/core/tfrt/common/pjrt_state.cc
- tensorflow/core/tfrt/common/pjrt_state.h 25 additions, 0 deletionstensorflow/core/tfrt/common/pjrt_state.h
- tensorflow/core/tfrt/common/pjrt_util.cc 33 additions, 0 deletionstensorflow/core/tfrt/common/pjrt_util.cc
- tensorflow/core/tfrt/common/pjrt_util.h 5 additions, 0 deletionstensorflow/core/tfrt/common/pjrt_util.h
- tensorflow/python/distribute/BUILD 25 additions, 0 deletionstensorflow/python/distribute/BUILD
- tensorflow/python/distribute/mwms_pjrt_gpu_test.py 116 additions, 0 deletionstensorflow/python/distribute/mwms_pjrt_gpu_test.py
- third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc 4 additions, 4 deletionsthird_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc
- third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h 10 additions, 0 deletionsthird_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h
Please register or sign in to comment