
Added Dockerfile for CI images & Upgrade CI to ROCm 7.2 #195

Open
VeeraRajasekhar wants to merge 11 commits into dev from dockerfile

Conversation

@VeeraRajasekhar
Contributor

Description

Added the Dockerfile, which can be used to build the CI artifactory images.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Added a new file docker/Dockerfile

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Collaborator

@wenchenvincent left a comment


Please address the comments.

Collaborator

@ipanfilo left a comment


Why are conversations being marked as resolved without any actual action?

@VeeraRajasekhar
Contributor Author

Why are conversations being marked as resolved without any actual action?

Some of them I have already resolved; others I have resolved locally but not pushed yet, and I mark them resolved just to keep track.

@wenchenvincent
Collaborator

@VeeraRajasekhar Is this PR still needed?

@wenchenvincent
Collaborator

@VeeraRajasekhar Could you remind me of what we decided on this PR? It seems it is no longer relevant and we should close it.

@VeeraRajasekhar
Contributor Author

Hi @ipanfilo, @wangye805

I have updated this PR with the latest ROCm 7.2 Dockerfile and moved it to .github/scripts.

Let me know if I need to add an action to automate the Docker build and upload to our artifactory.

Thanks.

RUN pip install ipython pytest fire pydantic pybind11 ninja pandas
RUN apt-get update && apt-get install -y vim

ARG PYTORCH_ROCM_ARCH=gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201;gfx950;gfx1151
Collaborator

Why is this arg needed?

Contributor Author

Not needed; I added it when I installed torch manually. I can remove it now.

Contributor Author

Addressed

@@ -0,0 +1,27 @@
# TE CI Dockerfile
Collaborator

Put copyright here

Contributor Author

Addressed


ARG PYTORCH_ROCM_ARCH=gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201;gfx950;gfx1151

# Install flash-attention v2.8.1
Collaborator

Better to make the FA version an ARG.

Contributor Author

Updated



RUN pip install \
https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/jax_rocm7_pjrt-0.8.0%2Brocm7.2.0-py3-none-manylinux_2_28_x86_64.whl \
Collaborator

Since the base image is an ARG, the JAX wheels should also be selected via ARGs, because their ROCm and Python versions must match the image. In general there should be ARGs for the ROCm version, the base image, the JAX version, and the JAX wheels.

Contributor Author

Updated
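
For context on the version coupling described above: the jaxlib wheel URL used later in the updated Dockerfile encodes the ROCm release, the JAX version, and the Python tag, so all three have to agree with the base image. A minimal Python sketch of that coupling (the version values below are examples only, not the ones the Dockerfile pins):

# Illustration only: the jaxlib wheel filename on repo.radeon.com encodes the
# ROCm release, the JAX version, and the CPython tag, so the ARGs that select
# the base image must also select a matching wheel.
rocm_version = "7.2"    # example; must match the ROCm release of the base image
jax_version = "0.8.0"   # example; must be a wheel published for that ROCm release
python_tag = "312"      # example; must match the image's Python interpreter

wheel_url = (
    f"https://repo.radeon.com/rocm/manylinux/rocm-rel-{rocm_version}/"
    f"jaxlib-{jax_version}%2Brocm{rocm_version}.0-"
    f"cp{python_tag}-cp{python_tag}-manylinux_2_27_x86_64.whl"
)
print(wheel_url)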

@VeeraRajasekhar
Contributor Author

I had to force-push to include the new FA 2.8.3 support commit and my changes for ROCm 7.2 support so the CI can run.

Thanks.

ROCm's jax.nn.scaled_matmul kernels require the contracting dimension (K)
to be at least 64. Without this validation, backward pass GEMMs with K < 64
cause segmentation faults.

Added K >= 64 check in _check_mxfp8_gemm_support() for JAX GEMM on ROCm.

Fixes: test_dense_grad_fp8[MXFP8_1D_SCALING-with_jax_gemm_True-64-32-64]
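
For reference, a minimal sketch of the kind of guard the commit message describes, assuming a helper that receives the operand shapes and contracting dimensions (the names and signature below are illustrative, not the exact Transformer Engine code):

MIN_ROCM_SCALED_MATMUL_K = 64  # ROCm's jax.nn.scaled_matmul kernels need K >= 64

def _check_mxfp8_gemm_support(lhs_shape, rhs_shape, contracting_dims):
    # Illustrative sketch: report the MXFP8 GEMM as unsupported when the
    # contracting dimension (K) is smaller than 64, since such GEMMs
    # segfault on ROCm instead of falling back gracefully.
    lhs_cdims, rhs_cdims = contracting_dims
    k_lhs = 1
    for dim in lhs_cdims:
        k_lhs *= lhs_shape[dim]
    k_rhs = 1
    for dim in rhs_cdims:
        k_rhs *= rhs_shape[dim]
    return k_lhs >= MIN_ROCM_SCALED_MATMUL_K and k_rhs >= MIN_ROCM_SCALED_MATMUL_K

With such a guard, a GEMM with K = 32 (as in the fixed test case) would be routed away from scaled_matmul rather than segfaulting.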
@VeeraRajasekhar
Contributor Author

@Micky774, please review the following:

Analysis from testing on JAX & XLA 0.8.2

(Not Supported) jax.nn.scaled_matmul (MXFP8) on ROCm crashes with a segmentation fault when the contracting dimension (K) is less than 64.

import functools
import jax
import jax.numpy as jnp
from jax import nn

key = jax.random.PRNGKey(0)
key_a, key_b = jax.random.split(key)
B, M, N, K = 1, 128, 128, 32


lhs = jax.random.normal(key_a, (B, M, K), dtype=jnp.float32)
rhs = jax.random.normal(key_b, (B, N, K), dtype=jnp.float32)

# 1. high-precision matmul
ref = jnp.einsum("bmk,bnk->bmn", lhs, rhs)

# 2. mxfp8 matmul
configs = [nn.get_scaled_dot_general_config("mxfp8")] * 3
scaled_dot = functools.partial(
    nn.scaled_dot_general,
    configs=configs,
    preferred_element_type=jnp.float32,
)

out = scaled_dot(lhs, rhs, (((2,), (2,)), ((0,), (0,))))

# compare results
print("high-precision ref: ")
print(ref)

print("mxfp8 out: ")
print(out)

max_abs = jnp.max(jnp.abs(out - ref))
max_rel = max_abs / jnp.max(jnp.abs(ref))

print("max abs error:", max_abs)
print("max rel error:", max_rel)

[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_32x32x64_UR_2: K must be a multiple of workgroupTile.k=64 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_16x16x128_UR_2: K must be a multiple of workgroupTile.k=128 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_32x16x128_UR_2: K must be a multiple of workgroupTile.k=128 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_16x32x128_UR_2: K must be a multiple of workgroupTile.k=128 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_32x32x128_UR_2: K must be a multiple of workgroupTile.k=128 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_64x16x128_UR_2: K must be a multiple of workgroupTile.k=128 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_16x64x128_UR_2: K must be a multiple of workgroupTile.k=128 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_64x32x128_UR_2: K must be a multiple of workgroupTile.k=128 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_32x64x128_UR_2: K must be a multiple of workgroupTile.k=128 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_16x16x256_UR_2: K must be a multiple of workgroupTile.k=256 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_64x64x128_UR_2: K must be a multiple of workgroupTile.k=128 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_128x32x128_UR_2: K must be a multiple of workgroupTile.k=128 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_32x128x128_UR_2: K must be a multiple of workgroupTile.k=128 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_128x64x128_UR_2: K must be a multiple of workgroupTile.k=128 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_64x128x128_UR_2: K must be a multiple of workgroupTile.k=128 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_16x64x256_UR_2: K must be a multiple of workgroupTile.k=256 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_192x32x128_UR_2: M must be a multiple of workgroupTile.m=192
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_192x32x128_UR_2: K must be a multiple of workgroupTile.k=128 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_128x128x128_UR_2: K must be a multiple of workgroupTile.k=128 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_32x192x128_UR_2: N must be a multiple of workgroupTile.n=192
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_32x192x128_UR_2: K must be a multiple of workgroupTile.k=128 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_192x64x128_UR_2: M must be a multiple of workgroupTile.m=192
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_192x64x128_UR_2: K must be a multiple of workgroupTile.k=128 * unrollK=1
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_64x192x128_UR_2: N must be a multiple of workgroupTile.n=192
[rr/error] Predicate mismatch for RR_GEMM_TN_FP8_FP8_Float_Float_Float_SA_B_SB_B_WGT_64x192x128_UR_2: K must be a multiple of workgroupTile.k=128 * unrollK=1
Segmentation fault (core dumped)


@Micky774
Contributor

Micky774 commented Feb 16, 2026

@Micky774, please review the following:

Analysis from testing on JAX & XLA 0.8.2

(Not Supported) jax.nn.scaled_matmul (MXFP8) on ROCm crashes with a segmentation fault when the contracting dimension (K) is less than 64.
...

This is a failure on certain configs for hipblaslt, which we already have tickets open for. We don't support these configs in TE anyways, so it's a known issue and not a blocker. Shouldn't be a problem for this PR.

@VeeraRajasekhar changed the title from "Added Dockerfile for CI images" to "Added Dockerfile for CI images & Upgrade CI to ROCm 7.2" on Feb 16, 2026
https://repo.radeon.com/rocm/manylinux/rocm-rel-${ROCM_VERSION}/jaxlib-${JAX_VERSION}%2Brocm${ROCM_VERSION}.0-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux_2_27_x86_64.whl

WORKDIR /workspace/
CMD ["/bin/bash"]
\ No newline at end of file
Collaborator

CRLF

Contributor Author

Fixed

RUN apt-get update && apt-get install -y vim

# Install flash-attention
ENV GPU_ARCHS=gfx90a;gfx950;gfx942
Collaborator

For CI purposes gfx90a is not needed. It is OK to keep it if it does not affect the build time.

Contributor Author

Updated
