#sdy Add extra `export_test.py` tests for using different meshes..
Under Shardy, we can: - use the same mesh on save and load - use one mesh on save and another mesh on load with different axis names - use one mesh on save and another mesh on load with different axis names and sizes. For this case Shardy propagation may not be optimal if the module doesn't specify out shardings. This is very hard to write a unit test for, and is rare to happen, and is something we have been considering adding in Shardy b/399957785. This will be something we can allow for during Shardy propagation. This can be a standalone fix in Shardy without making any changes to JAX or XLA. PiperOrigin-RevId: 729550047
Showing
- third_party/xla/xla/service/spmd/shardy/BUILD 1 addition, 0 deletionsthird_party/xla/xla/service/spmd/shardy/BUILD
- third_party/xla/xla/service/spmd/shardy/sdy_opt_main.cc 2 additions, 0 deletionsthird_party/xla/xla/service/spmd/shardy/sdy_opt_main.cc
- third_party/xla/xla/service/spmd/shardy/sdy_round_trip/BUILD 18 additions, 0 deletionsthird_party/xla/xla/service/spmd/shardy/sdy_round_trip/BUILD
- third_party/xla/xla/service/spmd/shardy/sdy_round_trip/dedup_meshes.cc 241 additions, 0 deletions...la/xla/service/spmd/shardy/sdy_round_trip/dedup_meshes.cc
- third_party/xla/xla/service/spmd/shardy/sdy_round_trip/dedup_meshes.h 43 additions, 0 deletions...xla/xla/service/spmd/shardy/sdy_round_trip/dedup_meshes.h
- third_party/xla/xla/service/spmd/shardy/sdy_round_trip/pipelines.cc 2 additions, 0 deletions...y/xla/xla/service/spmd/shardy/sdy_round_trip/pipelines.cc
- third_party/xla/xla/service/spmd/shardy/test/dedup_meshes.mlir 74 additions, 0 deletions..._party/xla/xla/service/spmd/shardy/test/dedup_meshes.mlir
Please register or sign in to comment