GLU Operator gives different Results on Dml EP compared to CPU EP #24311

Open

spgoswami1 opened this issue Apr 4, 2025 · 3 comments
Labels
ep:DML issues related to the DirectML execution provider

Comments

spgoswami1 commented Apr 4, 2025

Describe the issue

Here is a sample script; the outputs from the DML EP and the CPU EP are completely different:

import torch
import torch.onnx
import onnxruntime as ort
import numpy as np

class GLUModel(torch.nn.Module):
    def __init__(self):
        super(GLUModel, self).__init__()
        self.glu = torch.nn.GLU(dim=1)

    def forward(self, x):
        return self.glu(x)

model = GLUModel()
dummy_input = torch.randn(1, 4, 5)  # GLU(dim=1) halves the size-4 dim, so the output shape is (1, 2, 5)

onnx_path = "glu_model.onnx"
torch.onnx.export(model, dummy_input, onnx_path, input_names=["input"], output_names=["output"])

session_cpu = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
session_dml = ort.InferenceSession(onnx_path, providers=["DmlExecutionProvider"])

input_data = {"input": dummy_input.numpy()}

output_cpu = session_cpu.run(None, input_data)[0]
output_dml = session_dml.run(None, input_data)[0]

print(output_cpu)
print(output_dml)
CPU EP output:
[[[-0.17137252  0.03492244  1.0936418  -0.5506519  -0.14124376]
  [-0.6255165  -0.07267396  0.4950865  -0.34862173  0.08689743]]]

DML EP output:
[[[-0.27777183  1.6589377   0.23191208 -0.1863674  -0.1662567 ]
  [ 0.674737   -0.24098903  0.43470204 -0.25040996  0.01485676]]]
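
Appending a quick check like the one below to the script (a sketch; it reuses model, dummy_input, output_cpu, and output_dml from above) shows which EP agrees with the PyTorch reference:

import numpy as np

# PyTorch's own GLU output as the reference
expected = model(dummy_input).detach().numpy()

print("CPU EP matches PyTorch:", np.allclose(output_cpu, expected, atol=1e-5))
print("DML EP matches PyTorch:", np.allclose(output_dml, expected, atol=1e-5))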

To reproduce

Run the above code

Urgency

Low. I have worked around it with the manual implementation below, but it was hard to figure out why the DML and CPU outputs were so far apart.

import torch
import torch.nn as nn

class ManualGLU(nn.Module):
    def forward(self, x):
        # Split the input in half along the gated dimension
        # (use the same dim that torch.nn.GLU was given).
        split_size = x.shape[-1] // 2
        a, b = torch.split(x, split_size, dim=-1)

        # Manual sigmoid: sigmoid(x) = 1 / (1 + exp(-x))
        b_sigmoid = 1 / (1 + torch.exp(-b))

        return a * b_sigmoid
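
For reference, a quick sanity check (assuming an even-sized split dimension) that ManualGLU matches torch.nn.functional.glu:

import torch
import torch.nn.functional as F

x = torch.randn(1, 2, 6)  # the split dim (here the last one) must have even size
assert torch.allclose(ManualGLU()(x), F.glu(x, dim=-1))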

Platform

Windows

OS Version

10.0.26100 Build 26100

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.21.0

ONNX Runtime API

Python

Architecture

X64

Execution Provider

DirectML

Execution Provider Library Version

No response

github-actions bot added the ep:DML label on Apr 4, 2025
fdwr (Contributor) commented Apr 4, 2025

Can you attach the generated tiny .onnx file to the issue to see the exact graph? Do results differ using your decomposition too?

spgoswami1 (Author) commented Apr 7, 2025

Here is the ONNX file:
https://drive.google.com/file/d/1BkVrYoM-zBF_uhv4Jcp9EyPS_fvnLS4x/view?usp=sharing

Results are correct after using ManualGLU. Here is the ONNX file for the manual GLU:
https://drive.google.com/file/d/10jlB791-qEQcDzuyKp0FCgO1kO8Skf9I/view?usp=sharing

spgoswami1 (Author) commented

The DML and CPU EP outputs for the Sigmoid alone are also the same. Does that mean there is some layout change in the output of Sigmoid between the CPU EP and the DML EP?

It certainly seems that way: with an input of shape (1, 1, 2), e.g. tensor([[[-0.2934, -0.2934]]]), the output from both EPs is exactly the same.

So, due to this layout change, the wrong elements end up being multiplied together in the Mul op.
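
One way to test this hypothesis (a sketch, not verified here): register the Sigmoid node's output as an extra graph output and compare it across both EPs. This assumes the exported graph contains a single Sigmoid node; check the actual graph in Netron if the exporter produced something different.

import numpy as np
import onnx
import onnxruntime as ort
from onnx import TensorProto, helper

m = onnx.load("glu_model.onnx")

# Expose the output tensor of the (assumed single) Sigmoid node as an
# additional graph output so both EPs return it.
sigmoid_out = next(n.output[0] for n in m.graph.node if n.op_type == "Sigmoid")
m.graph.output.append(helper.make_tensor_value_info(sigmoid_out, TensorProto.FLOAT, None))
onnx.save(m, "glu_model_debug.onnx")

x = np.random.randn(1, 4, 5).astype(np.float32)
for provider in ["CPUExecutionProvider", "DmlExecutionProvider"]:
    sess = ort.InferenceSession("glu_model_debug.onnx", providers=[provider])
    print(provider)
    print(sess.run(None, {"input": x})[-1])  # the appended Sigmoid output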
