Added Dockerfile for CI images & Upgrate CI to ROCm 7.2#195
Added Dockerfile for CI images & Upgrate CI to ROCm 7.2#195VeeraRajasekhar wants to merge 11 commits intodevfrom
Conversation
wenchenvincent
left a comment
There was a problem hiding this comment.
Please address the comments.
ipanfilo
left a comment
There was a problem hiding this comment.
Why conversations are marked as resolved w/o any actual action?
Some of them, I have resolved, some I have currently resolved in my local, just to keep track I will mark them resolved. |
|
@VeeraRajasekhar Is this PR still needed? |
|
@VeeraRajasekhar Could you remind me of what we had decided on this PR? It seemed that it is no longer relevant and we should close it. |
c4913b2 to
e49b365
Compare
|
Hi @ipanfilo, @wangye805 I have updated this PR with latest 7.2 docker file and moved to .github/scripts. Let me know if I need to add an action to automate docker build and upload to our artifactory? Thanks. |
.github/scripts/Dockerfile.ci.deps
Outdated
| RUN pip install ipython pytest fire pydantic pybind11 ninja pandas | ||
| RUN apt-get update && apt-get install -y vim | ||
|
|
||
| ARG PYTORCH_ROCM_ARCH=gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201;gfx950;gfx1151 |
There was a problem hiding this comment.
Not needed, when I manually installed torch I mentioned this, I can remove this now.
.github/scripts/Dockerfile.ci.deps
Outdated
| @@ -0,0 +1,27 @@ | |||
| # TE CI Dockerfile | |||
.github/scripts/Dockerfile.ci.deps
Outdated
|
|
||
| ARG PYTORCH_ROCM_ARCH=gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201;gfx950;gfx1151 | ||
|
|
||
| # Install flash-attention v2.8.1 |
There was a problem hiding this comment.
Better make FA version arg
.github/scripts/Dockerfile.ci.deps
Outdated
|
|
||
|
|
||
| RUN pip install \ | ||
| https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/jax_rocm7_pjrt-0.8.0%2Brocm7.2.0-py3-none-manylinux_2_28_x86_64.whl \ |
There was a problem hiding this comment.
Since base image is ARG, JAX wheels should also be built based on ARG cause their ROCm and Python versions should match. So in general there should be ARGs for ROCm version, base image, JAX version, JAX wheels
e49b365 to
bdc75a2
Compare
|
I had to force push to include new FA 2.8.3 support commit and my changes for 7.2 support to run the CI. Thanks. |
ROCm's jax.nn.scaled_matmul kernels require the contracting dimension (K) to be at least 64. Without this validation, backward pass GEMMs with K < 64 cause segmentation faults. Added K >= 64 check in _check_mxfp8_gemm_support() for JAX GEMM on ROCm. Fixes: test_dense_grad_fp8[MXFP8_1D_SCALING-with_jax_gemm_True-64-32-64]
|
@Micky774, please review the following, Analysis on testing on Jax & xla 0.8.2(Not Supported) jax.nn.scaled_matmul (MXFP8) on ROCm crashes with a segmentation fault when the contracting dimension (K) is less than 64. |
|
https://github.com/ROCm/TransformerEngine/actions/runs/22024090773/job/63637830333 Level=3 testing had no issues. |
This is a failure on certain configs for hipblaslt, which we already have tickets open for. We don't support these configs in TE anyways, so it's a known issue and not a blocker. Shouldn't be a problem for this PR. |
.github/scripts/Dockerfile.ci.deps
Outdated
| https://repo.radeon.com/rocm/manylinux/rocm-rel-${ROCM_VERSION}/jaxlib-${JAX_VERSION}%2Brocm${ROCM_VERSION}.0-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux_2_27_x86_64.whl | ||
|
|
||
| WORKDIR /workspace/ | ||
| CMD ["/bin/bash"] No newline at end of file |
.github/scripts/Dockerfile.ci.deps
Outdated
| RUN apt-get update && apt-get install -y vim | ||
|
|
||
| # Install flash-attention | ||
| ENV GPU_ARCHS=gfx90a;gfx950;gfx942 |
There was a problem hiding this comment.
For CI purposes gfx90a is not needed. It is OK to keep if does not affect build time
Description
Added the dockerfile, which can be used to create the ci-artifactory images.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: