
Commit 5ce3639

Committed on Mar 10, 2025
new design of t5 onnx
1 parent 49328fe commit 5ce3639

12 files changed: 1145 additions & 580 deletions

 

‎onnxruntime/contrib_ops/cpu/transformers/beam_search.cc

Lines changed: 11 additions & 5 deletions
@@ -139,13 +139,19 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state,
139139
ORT_RETURN_IF_ERROR(t5_encoder_subgraph_->Setup(session_state, subgraph_session_state));
140140
encoder_feeds_fetches_manager_ = t5_encoder_subgraph_->GetFeedsFetchesManager();
141141

142-
if (parameters_->decoder_start_token_id < 0) {
143-
ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 2,
144-
"Encoder subgraph shall have 2 inputs when decoder_start_token_id attribute is empty");
142+
if (t5_encoder_subgraph_->HasLogitsOutput()) {
143+
// New format requires start token id.
144+
ORT_ENFORCE(parameters_->decoder_start_token_id >= 0);
145145
} else {
146-
ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 3,
147-
"Encoder subgraph shall have 3 inputs when decoder_start_token_id attribute is available");
146+
if (parameters_->decoder_start_token_id < 0) {
147+
ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 2,
148+
"Encoder subgraph shall have 2 inputs when decoder_start_token_id attribute is empty");
149+
} else {
150+
ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 3,
151+
"Encoder subgraph shall have 3 inputs when decoder_start_token_id attribute is available");
152+
}
148153
}
154+
149155
} else if (attribute_name == "decoder") {
150156
ORT_ENFORCE(t5_decoder_subgraph_ == nullptr,
151157
"SetupSubgraphExecutionInfo should only be called once for each subgraph.");

‎onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h

Lines changed: 16 additions & 4 deletions
@@ -51,7 +51,13 @@ class BeamSearchT5 : public BeamSearchBase<T> {
5151
expand_buffer_int32_func_(expand_buffer_int32_func),
5252
expand_buffer_float_func_(expand_buffer_float_func),
5353
expand_buffer_float16_func_(expand_buffer_float16_func),
54-
create_beam_scorer_func_(create_beam_scorer_func) {}
54+
create_beam_scorer_func_(create_beam_scorer_func) {
55+
// When the decoder uses encoder_hidden_states, make sure the encoder outputs it.
56+
if (decoder_subgraph_.UseEncoderHiddenState()) {
57+
ORT_ENFORCE(encoder_subgraph_.subgraph_output_names[1] == "encoder_hidden_states");
58+
}
59+
ORT_ENFORCE(encoder_subgraph_.num_layers == decoder_subgraph_.num_layers);
60+
}
5561

5662
#ifdef USE_CUDA
5763
Status InitializeCuda(
@@ -160,7 +166,7 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
160166
this->create_encoder_inputs_func_,
161167
this->add_to_feeds_func_,
162168
buffer,
163-
decoder_input_ids,
169+
decoder_input_ids, // The new format does not use decoder_input_ids in the encoder; it is still initialized here when decoder_start_token_id >= 0.
164170
this->ort_stream_));
165171

166172
#ifdef DEBUG_NODE_INPUTS_OUTPUTS
@@ -233,15 +239,20 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
233239

234240
std::vector<OrtValue> decoder_fetches;
235241

236-
if (current_length + 1 < parameters->max_length) {
242+
// When the encoder outputs logits (old format), we need to get the next token from the logits.
243+
if (current_length + 1 < parameters->max_length && encoder_subgraph_.HasLogitsOutput()) {
237244
++iteration_counter;
238-
ORT_RETURN_IF_ERROR(this->GenerateNextToken(encoder_fetches[0],
245+
const OrtValue& logits = encoder_fetches[0];
246+
ORT_RETURN_IF_ERROR(this->GenerateNextToken(logits,
239247
beam_next_tokens,
240248
beam_state,
241249
cpu_state,
242250
iteration_counter));
243251
++current_length; // Increase sequence length after a new token is generated.
252+
}
244253

254+
// Generate inputs for next decoder subgraph call.
255+
if (current_length < parameters->max_length) {
245256
ORT_RETURN_IF_ERROR(decoder_subgraph_.CreateInitialFeeds(this->cpu_allocator_,
246257
ReinterpretAsSpan<const int32_t>(beam_next_tokens),
247258
this->implicit_inputs_,
@@ -262,6 +273,7 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
262273
this->cuda_device_prop_ != nullptr));
263274

264275
if (decoder_subgraph_.past_present_share_buffer_) {
276+
// Configure buffer sharing of past and present kv cache.
265277
decoder_fetches.reserve(static_cast<size_t>(decoder_subgraph_.GetFirstPresentOutputIndex()) +
266278
2 * static_cast<size_t>(decoder_subgraph_.num_layers));
267279
decoder_fetches.resize(decoder_subgraph_.GetFirstPresentOutputIndex(), OrtValue());
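
To make the flow change concrete, here is a minimal Python sketch of where the first decoder token comes from in each format. It is illustrative only: the function name is hypothetical and the greedy argmax stands in for the real beam-search token selection.

import numpy as np

def first_decoder_tokens(encoder_outputs, output_names, decoder_start_token_id, batch_beam_size):
    # Old format: the combined encoder/decoder-init graph already ran one decoder
    # step, so the first token is selected from its logits output.
    if "logits" in output_names:
        logits = encoder_outputs["logits"]  # shape (batch_beam_size, 1, vocab_size)
        return np.argmax(logits[:, -1, :], axis=-1).astype(np.int32)
    # New format: the encoder returns only cross-attention KV caches, so decoding
    # starts directly from decoder_start_token_id.
    assert decoder_start_token_id >= 0
    return np.full((batch_beam_size,), decoder_start_token_id, dtype=np.int32)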

‎onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc

Lines changed: 8 additions & 10 deletions
@@ -36,12 +36,9 @@ Subgraph::Subgraph(
3636
auto& subgraph_inputs = subgraph.GetInputs();
3737
auto& subgraph_outputs = subgraph.GetOutputs();
3838

39-
// inputs: input_ids, position_ids, attention_mask, past_0, past_1, ...
40-
// outputs: logits, present_0, present_1, ...
4139
num_subgraph_inputs = static_cast<int>(subgraph_inputs.size());
4240
num_subgraph_outputs = static_cast<int>(subgraph_outputs.size());
4341

44-
// CheckSubgraph will verify inputs and outputs later.
4542
subgraph_input_names.reserve(num_subgraph_inputs);
4643
for (int i = 0; i < num_subgraph_inputs; ++i) {
4744
subgraph_input_names.push_back(subgraph_inputs[i]->Name());
@@ -68,10 +65,9 @@ Status Subgraph::Setup(const SessionState& session_state,
6865
InlinedVector<std::string_view> feed_names;
6966
feed_names.reserve(static_cast<size_t>(num_subgraph_inputs) + static_cast<size_t>(num_implicit_inputs));
7067

71-
// Use the first output (logits) to find device location.
68+
// Use the first output to find device location.
7269
const OrtDevice& default_location = utils::FindDeviceForValue(subgraph_session_state, subgraph_output_names[0]);
7370

74-
// The position_ids, attention_mask, past_0, ... are created by this operator so the name doesn't matter.
7571
feed_names.insert(feed_names.end(), subgraph_input_names.begin(), subgraph_input_names.end());
7672

7773
const auto& subgraph_map = subgraph_session_state.GetOrtValueNameIdxMap();
@@ -174,13 +170,15 @@ Status Subgraph::GetParameters(const ONNX_NAMESPACE::TensorShapeProto* past_shap
174170
}
175171

176172
// Logits shape is like (batch_size, seq_len, vocabulary_size)
177-
ORT_RETURN_IF(logits_shape->dim_size() != 3,
178-
"subgraph logits output is expected to have 3 dimension, got ", logits_shape->dim_size());
173+
if (logits_shape != nullptr) {
174+
ORT_RETURN_IF(logits_shape->dim_size() != 3,
175+
"subgraph logits output is expected to have 3 dimension, got ", logits_shape->dim_size());
179176

180-
ORT_RETURN_IF(!logits_shape->dim(2).has_dim_value() || logits_shape->dim(2).dim_value() <= 0,
181-
"subgraph past state dimension 2 shall have a positive value for vocabulary size");
177+
ORT_RETURN_IF(!logits_shape->dim(2).has_dim_value() || logits_shape->dim(2).dim_value() <= 0,
178+
"subgraph past state dimension 2 shall have a positive value for vocabulary size");
182179

183-
this->vocab_size = static_cast<int>(logits_shape->dim(2).dim_value());
180+
this->vocab_size = static_cast<int>(logits_shape->dim(2).dim_value());
181+
}
184182

185183
return Status::OK();
186184
}
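
A small sketch of the relaxed parameter deduction above (the helper name is hypothetical; the real code works on TensorShapeProto objects rather than plain tuples):

def deduce_vocab_size(logits_shape):
    # In the new format the encoder has no logits output, so logits_shape is None
    # and vocab_size is simply not deduced from the encoder subgraph.
    if logits_shape is None:
        return None
    batch_size, seq_len, vocab_size = logits_shape  # logits is (batch_size, seq_len, vocab_size)
    assert vocab_size > 0
    return vocab_size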

‎onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc

Lines changed: 83 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -141,14 +141,23 @@ Status T5DecoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_i
141141
}
142142

143143
// Create inputs for decoder from the following data sources:
144-
// encoder feeds: encoder_input_ids, encoder_attention_mask, decoder_input_ids (with start tokens)
145-
// encoder fetches: logits,
146-
// encoder_hidden_states,
147-
// present_key_self_0, present_value_self_0, ..., present_key_cross_0, present_value_cross_0, ...
148-
// decoder_feeds: input_ids,
149-
// encoder_attention_mask,
150-
// encoder_hidden_states,
151-
// present_key_self_0, present_value_self_0, ..., present_key_cross_0, present_value_cross_0, ...
144+
// New format:
145+
// encoder feeds: encoder_input_ids, encoder_attention_mask
146+
// encoder fetches: present_key_cross_0, present_value_cross_0, ...
147+
// decoder_feeds: input_ids, encoder_attention_mask,
148+
// present_key_self_0, present_value_self_0, ...,
149+
// present_key_cross_0, present_value_cross_0, ...
150+
151+
// Old format:
152+
// encoder feeds: encoder_input_ids, encoder_attention_mask, decoder_input_ids (with start tokens)
153+
// encoder fetches: logits, encoder_hidden_states,
154+
// present_key_self_0, present_value_self_0, ...,
155+
// present_key_cross_0, present_value_cross_0, ...
156+
// decoder_feeds: input_ids, encoder_input_ids (optional), encoder_attention_mask, encoder_hidden_states,
157+
// present_key_self_0, present_value_self_0, ...,
158+
// present_key_cross_0, present_value_cross_0, ...
159+
// past_seq_len (optional), num_beams (optional), cache_indirection (optional)
160+
152161
Status T5DecoderSubgraph::CreateInitialFeeds(
153162
AllocatorPtr cpu_allocator,
154163
gsl::span<const int32_t> beam_next_tokens,
@@ -173,33 +182,30 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
173182
// Allocate subgraph inputs from same device as inputs of encoder subgraph.
174183
AllocatorPtr allocator = session_state_->GetAllocator(encoder_feeds[0].Get<Tensor>().Location());
175184

185+
int batch_beam_size = static_cast<int>(encoder_fetches[0].Get<Tensor>().Shape()[0]) * num_beam;
186+
176187
// Copy beam next tokens in CPU to input_ids in provider device (CPU for CPU EP, or GPU for CUDA EP).
177-
int batch_beam_size = static_cast<int>(beam_next_tokens.size());
178188
int sequence_length = !use_sequence_as_input_ids ? 1 : cur_len;
179189
int64_t dims[] = {batch_beam_size, sequence_length};
180190
TensorShape input_ids_shape(&dims[0], 2);
181191
OrtValue input_ids;
182192
Tensor::InitOrtValue(DataTypeImpl::GetType<int32_t>(), input_ids_shape, allocator, input_ids);
183-
int32_t* input_ids_data = input_ids.GetMutable<Tensor>()->MutableData<int32_t>();
184-
AllocatorPtr buffer_allocator = std::make_shared<onnxruntime::CPUAllocator>();
185-
size_t total_size = static_cast<size_t>(cur_len) * static_cast<size_t>(batch_beam_size);
186-
size_t total_size_bytes = total_size * sizeof(int);
187-
auto seq_copy = IAllocator::MakeUniquePtr<int>(buffer_allocator, total_size_bytes, false, stream);
188-
int* seq_copy_ptr = seq_copy.get();
189-
190-
if (!use_sequence_as_input_ids_) {
193+
194+
// Prepare data for input_ids.
195+
if (!use_sequence_as_input_ids_) {  // Use next tokens for input_ids. This is for the Whisper model.
191196
ORT_RETURN_IF_ERROR(device_copy_int32_func(
192197
input_ids.GetMutable<Tensor>()->MutableDataAsSpan<int32_t>(),
193198
beam_next_tokens,
194199
stream,
195200
DeviceCopyDirection::hostToDevice));
196201
} else {
202+
int32_t* input_ids_data = input_ids.GetMutable<Tensor>()->MutableData<int32_t>();
197203
if (use_cuda) {
198204
auto sequences_buffer = sequences.GetCurrentDeviceSequences();
199205
for (int i = 0; i < batch_beam_size; i++) {
200-
size_t batch_beam_stride = static_cast<size_t>(i) * static_cast<size_t>(sequences.GetMaxLength());
206+
size_t offset = static_cast<size_t>(i) * static_cast<size_t>(sequences.GetMaxLength());
201207
int seq_size = sequences.GetSequenceLength();
202-
gsl::span<const int32_t> sequence = sequences_buffer.subspan(batch_beam_stride, seq_size);
208+
gsl::span<const int32_t> sequence = sequences_buffer.subspan(offset, seq_size);
203209
gsl::span<int> temp_input(input_ids_data + static_cast<ptrdiff_t>(i) * seq_size, seq_size);
204210
ORT_RETURN_IF_ERROR(device_copy_int32_func(
205211
temp_input,
@@ -208,6 +214,13 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
208214
DeviceCopyDirection::deviceToDevice));
209215
}
210216
} else {
217+
size_t total_size = static_cast<size_t>(cur_len) * static_cast<size_t>(batch_beam_size);
218+
size_t total_size_bytes = total_size * sizeof(int);
219+
AllocatorPtr buffer_allocator = std::make_shared<onnxruntime::CPUAllocator>();
220+
// TODO: no extra buffer is needed. Copy directly to input_ids_data instead, like the use_cuda path above.
221+
auto seq_copy = IAllocator::MakeUniquePtr<int>(buffer_allocator, total_size_bytes, false, stream);
222+
int* seq_copy_ptr = seq_copy.get();
223+
211224
const size_t cur_len_bytes = cur_len * sizeof(int);
212225
for (int i = 0; i < batch_beam_size; i++) {
213226
gsl::span<const int32_t> sequence = sequences.GetSequence(i);
@@ -227,9 +240,11 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
227240

228241
// The ordering is the same as used in Setup.
229242
decoder_feeds.reserve(static_cast<size_t>(num_subgraph_inputs) + static_cast<size_t>(num_implicit_inputs));
243+
244+
// input 0: input_ids
230245
decoder_feeds.push_back(input_ids);
231246

232-
if (has_encoder_input_ids_) {
247+
if (has_encoder_input_ids_) { // encoder_input_ids is optional
233248
// The encoder_input_ids is copied from the first input of encoder.
234249
OrtValue expanded_encoder_input_ids;
235250
ORT_RETURN_IF_ERROR(expand_buffer_int32_func(stream,
@@ -251,70 +266,65 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
251266
expanded_decoder_attention_masks,
252267
false,
253268
0 /*max_sequence_length*/));
254-
255269
decoder_feeds.push_back(expanded_decoder_attention_masks);
256270

257271
if (!past_present_share_buffer_) {
258272
past_present_share_buffer_max_seq_len = 0;
259273
}
260274

261-
// When first_past_input_index_ == 3, the encoder_hidden_states and past states are copied from the second output
262-
// of encoder.
263-
// When first_past_input_index_ == 2, the past states are copied from the second output of encoder.
264-
// TODO - probably more robust to introduce a encoder_out/decoder_in mapping instead of relying on positions.
265-
// What happens if encoder_hidden_states is present in the encoder_fetches but not in the decoder_feeds?
266-
for (size_t j = static_cast<size_t>(2) - has_hidden_state_; j < encoder_fetches.size(); j++) {
267-
if (j == 1) {
268-
ORT_RETURN_IF(has_hidden_state_ == false, "Invalid hidden_states expension: has_hidden_state_ == false");
269-
OrtValue expanded_hidden_states;
270-
if (is_output_float16_) {
271-
ORT_RETURN_IF_ERROR(expand_buffer_float16_func(stream,
272-
encoder_fetches[j],
273-
num_beam,
274-
allocator,
275-
expanded_hidden_states,
276-
false,
277-
0 /*max_sequence_length*/));
278-
} else {
279-
ORT_RETURN_IF_ERROR(expand_buffer_float_func(stream,
280-
encoder_fetches[j],
281-
num_beam,
282-
allocator,
283-
expanded_hidden_states,
284-
false,
285-
0 /*max_sequence_length*/));
286-
}
287-
decoder_feeds.push_back(expanded_hidden_states);
288-
} else {
275+
// Macro to expand an encoder output and append it to the decoder feeds.
276+
#define ADD_DECODER_FEED(encoder_output, is_dynamic_kv_cache) \
277+
OrtValue expanded; \
278+
if (is_output_float16_) { \
279+
ORT_RETURN_IF_ERROR(expand_buffer_float16_func(stream, encoder_output, num_beam, allocator, expanded, false, \
280+
is_dynamic_kv_cache ? past_present_share_buffer_max_seq_len : 0)); \
281+
} else { \
282+
ORT_RETURN_IF_ERROR(expand_buffer_float_func(stream, encoder_output, num_beam, allocator, expanded, false, \
283+
is_dynamic_kv_cache ? past_present_share_buffer_max_seq_len : 0)); \
284+
} \
285+
decoder_feeds.push_back(expanded);
286+
287+
// The encoder_hidden_states is copied from the second output of encoder.
288+
if (has_hidden_state_) {
289+
ADD_DECODER_FEED(encoder_fetches[1], false);
290+
}
291+
292+
// The new encoder format has only cross-attention outputs.
293+
bool is_new_format = (static_cast<int>(encoder_fetches.size()) == 2 * num_layers);
294+
if (is_new_format) {
295+
for (int i = 0; i < 2 * num_layers; i++) {
296+
// cross shape is (batch_size, num_heads, encode_sequence_length, head_size)
297+
const TensorShape& cross_shape = encoder_fetches[0].Get<Tensor>().Shape();
298+
ORT_ENFORCE(cross_shape.NumDimensions() == 4);
299+
300+
// Shape for kv cache: (batch_size * num_beam, num_heads, max_seq_len, head_size)
301+
int64_t cache_dims[4] = {0};
302+
cross_shape.CopyDims(cache_dims, cross_shape.NumDimensions());
303+
cache_dims[0] *= num_beam;
304+
cache_dims[2] = past_present_share_buffer_max_seq_len;
305+
TensorShape expanded_shape(&cache_dims[0], cross_shape.NumDimensions());
306+
307+
MLDataType element_type = encoder_fetches[0].Get<Tensor>().DataType();
308+
OrtValue past;
309+
Tensor::InitOrtValue(element_type, expanded_shape, allocator, past);
310+
decoder_feeds.push_back(past);
311+
}
312+
313+
// Add cross inputs from encoder output.
314+
for (size_t j = 0; j < encoder_fetches.size(); j++) {
315+
ADD_DECODER_FEED(encoder_fetches[j], false);
316+
}
317+
} else {
318+
for (size_t j = 1 + has_hidden_state_; j < encoder_fetches.size(); j++) {
289319
// past key/value for cross attention does not need to be initialized with max_seq_len since they are static.
290-
bool use_max_seq_len = (j - first_past_input_index_) < 2 * static_cast<size_t>(num_layers);
291-
292-
OrtValue expanded_cache;
293-
if (is_output_float16_) {
294-
ORT_RETURN_IF_ERROR(expand_buffer_float16_func(stream,
295-
encoder_fetches[j],
296-
num_beam,
297-
allocator,
298-
expanded_cache,
299-
false,
300-
use_max_seq_len ? past_present_share_buffer_max_seq_len : 0));
301-
} else {
302-
ORT_RETURN_IF_ERROR(expand_buffer_float_func(stream,
303-
encoder_fetches[j],
304-
num_beam,
305-
allocator,
306-
expanded_cache,
307-
false,
308-
use_max_seq_len ? past_present_share_buffer_max_seq_len : 0));
309-
}
310-
decoder_feeds.push_back(expanded_cache);
320+
bool is_dynamic_kv_cache = (j - first_past_input_index_) < 2 * static_cast<size_t>(num_layers);
321+
ADD_DECODER_FEED(encoder_fetches[j], is_dynamic_kv_cache);
311322
}
312323
}
313324

314-
// TODO: This part shares the similar logic with CreateInitialFeeds() in subgraph_gpt.cc. We should refactor it.
315325
if (past_present_share_buffer_) {
316-
// Past sequence length feed
317-
ORT_RETURN_IF_ERROR(AppendPastSequenceLength(decoder_feeds, cpu_allocator, 1));
326+
// Past sequence length: 0 for the new format, 1 for the old format.
327+
ORT_RETURN_IF_ERROR(AppendPastSequenceLength(decoder_feeds, cpu_allocator, is_new_format ? 0 : 1));
318328
// Add beam search specific inputs
319329
if (need_cache_indir) {
320330
const int64_t batch_size = static_cast<int64_t>(batch_beam_size / num_beam);
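
The new-format feed layout built above can be summarized with a short numpy sketch. It is a simplified stand-in: the function name is made up, the self-attention buffers are zero-filled here for clarity (ORT only allocates them), and input_ids is assumed to already have batch_size * num_beams rows.

import numpy as np

def build_new_format_decoder_feeds(input_ids, encoder_attention_mask,
                                    encoder_cross_kv, num_beams, max_seq_len):
    # encoder_cross_kv holds 2 * num_layers arrays of shape
    # (batch_size, num_heads, encode_sequence_length, head_size).
    feeds = [input_ids,
             np.repeat(encoder_attention_mask, num_beams, axis=0)]
    # Empty self-attention caches sized for buffer sharing:
    # (batch_size * num_beams, num_heads, max_seq_len, head_size).
    for kv in encoder_cross_kv:
        batch_size, num_heads, _, head_size = kv.shape
        feeds.append(np.zeros((batch_size * num_beams, num_heads, max_seq_len, head_size), dtype=kv.dtype))
    # Cross-attention caches: the encoder outputs expanded num_beams times along the batch axis.
    for kv in encoder_cross_kv:
        feeds.append(np.repeat(kv, num_beams, axis=0))
    return feeds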

‎onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h

Lines changed: 4 additions & 0 deletions
@@ -72,6 +72,10 @@ class T5DecoderSubgraph : public Subgraph {
7272
return use_sequence_as_input_ids_;
7373
}
7474

75+
inline bool UseEncoderHiddenState() const {
76+
return has_hidden_state_;
77+
}
78+
7579
protected:
7680
int first_past_input_index_;
7781
int first_present_output_index_;

‎onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc

Lines changed: 72 additions & 45 deletions
@@ -15,78 +15,105 @@ namespace transformers {
1515

1616
/* T5 Encoder Subgraph (It also contains decoder initialization where decoder_input_ids are filled with start token ID).
1717
18-
Inputs:
18+
New format:
19+
Inputs:
1920
encoder_input_ids: int32 (B, encode_sequence_length)
2021
encoder_attention_mask: int32 (B, encode_sequence_length)
21-
decoder_input_ids: int32 (B, 1)
2222
2323
Outputs:
24-
logits: (B, 1, vocab_size)
25-
encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size)
26-
27-
present_key_self_0: (B, num_heads, 1, head_size)
28-
present_value_self_0: (B, num_heads, 1, head_size)
29-
... (for each self attention layer)
30-
3124
present_key_cross_0: (B, num_heads, encode_sequence_length, head_size)
3225
present_value_cross_0: (B, num_heads, encode_sequence_length, head_size)
3326
... (for each cross attention layer)
3427
35-
Note:
36-
Here, B = batch_size * num_beams since we expand the inputs.
37-
Ideally, we could use B=batch_size and expand the outputs with a factor of num_beams.
38-
Data type of input or output is float or float16 if not specified.
28+
Old format:
29+
Inputs:
30+
encoder_input_ids: int32 (B, encode_sequence_length)
31+
encoder_attention_mask: int32 (B, encode_sequence_length)
32+
decoder_input_ids: int32 (B, 1)
33+
34+
Outputs:
35+
logits: (B, 1, vocab_size)
36+
encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size)
37+
38+
present_key_self_0: (B, num_heads, 1, head_size)
39+
present_value_self_0: (B, num_heads, 1, head_size)
40+
... (for each self attention layer)
41+
42+
present_key_cross_0: (B, num_heads, encode_sequence_length, head_size)
43+
present_value_cross_0: (B, num_heads, encode_sequence_length, head_size)
44+
... (for each cross attention layer)
45+
46+
Note:
47+
Here, B = batch_size * num_beams since we expand the inputs.
48+
Ideally, we could use B=batch_size and expand the outputs with a factor of num_beams.
49+
Data type of input or output is float or float16 if not specified.
3950
*/
4051

4152
Status T5EncoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_inputs,
4253
const std::vector<const NodeArg*>& subgraph_outputs) {
43-
ORT_RETURN_IF(num_subgraph_inputs != 3, "expect 3 inputs, got:", num_subgraph_inputs);
44-
45-
ORT_RETURN_IF(num_subgraph_outputs < 6, "expect >=6 outputs, got:", num_subgraph_outputs);
46-
ORT_RETURN_IF((static_cast<int>(subgraph_outputs.size()) - first_present_output_index_) % 4 != 0,
47-
"number of outputs expected to be 2 + 4 * layers, got:", num_subgraph_outputs);
54+
constexpr auto int32_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32;
55+
constexpr auto float32_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT;
56+
constexpr auto float16_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16;
4857

58+
ORT_RETURN_IF(num_subgraph_inputs != 2 && num_subgraph_inputs != 3, "expect 2 or 3 inputs, got:", num_subgraph_inputs);
4959
ORT_RETURN_IF(subgraph_inputs[0]->Name() != "encoder_input_ids",
5060
"encoder subgraph input 0 shall be named as encoder_input_ids, got: ", subgraph_inputs[0]->Name());
5161
ORT_RETURN_IF(subgraph_inputs[1]->Name() != "encoder_attention_mask",
5262
"encoder subgraph input 1 shall be named as encoder_attention_mask, got: ", subgraph_inputs[1]->Name());
53-
ORT_RETURN_IF(subgraph_inputs[2]->Name() != "decoder_input_ids",
54-
"encoder subgraph input 2 shall be named as decoder_input_ids, got: ", subgraph_inputs[2]->Name());
55-
56-
ORT_RETURN_IF(subgraph_outputs[0]->Name() != "logits",
57-
"encoder subgraph output 0 shall be named as logits, got: ", subgraph_outputs[0]->Name());
58-
ORT_RETURN_IF(subgraph_outputs[1]->Name() != "encoder_hidden_states",
59-
"encoder subgraph output 1 shall be named encoder_hidden_states, got: ", subgraph_outputs[1]->Name());
60-
ORT_RETURN_IF(subgraph_outputs[2]->Name() != "present_key_self_0",
61-
"encoder subgraph output 2 shall be named as present_key_self_0, got: ", subgraph_outputs[2]->Name());
62-
ORT_RETURN_IF(subgraph_outputs[3]->Name() != "present_value_self_0",
63-
"encoder subgraph output 3 shall be named as present_value_self_0, got: ", subgraph_outputs[3]->Name());
64-
65-
const ONNX_NAMESPACE::TensorShapeProto* past_shape = subgraph_outputs[2]->Shape();
66-
const ONNX_NAMESPACE::TensorShapeProto* logits_shape = subgraph_outputs[0]->Shape();
67-
68-
// Save parameters related to the subgraph.
69-
ORT_RETURN_IF_ERROR(GetParameters(past_shape, logits_shape, false));
70-
num_layers = (static_cast<int>(subgraph_outputs.size()) - first_present_output_index_) / 4;
71-
72-
constexpr auto int32_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32;
73-
constexpr auto float32_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT;
74-
constexpr auto float16_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16;
7563

7664
ORT_RETURN_IF(subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != int32_type,
7765
"encoder subgraph input 0 (encoder_input_ids) shall have int32 type");
7866
ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != int32_type,
7967
"encoder subgraph input 1 (encoder_attention_mask) shall have int32 type");
80-
ORT_RETURN_IF(subgraph_inputs[2]->TypeAsProto()->tensor_type().elem_type() != int32_type,
81-
"encoder subgraph input 2 (decoder_input_ids) shall have int32 type");
68+
69+
if (num_subgraph_inputs == 2) {
70+
ORT_RETURN_IF(num_subgraph_outputs < 2 || num_subgraph_outputs % 2 != 0,
71+
"number of outputs expected to be 2 * layers, got:", num_subgraph_outputs);
72+
73+
ORT_RETURN_IF(subgraph_outputs[0]->Name() != "present_key_cross_0",
74+
"encoder subgraph output 0 shall be named as present_key_cross_0, got: ", subgraph_outputs[0]->Name());
75+
ORT_RETURN_IF(subgraph_outputs[1]->Name() != "present_value_cross_0",
76+
"encoder subgraph output 1 shall be named as present_value_cross_0, got: ", subgraph_outputs[1]->Name());
77+
78+
// Deduce num_heads and head_size parameters from shape of graph outputs
79+
const ONNX_NAMESPACE::TensorShapeProto* past_shape = subgraph_outputs[0]->Shape();
80+
const ONNX_NAMESPACE::TensorShapeProto* logits_shape = nullptr;
81+
ORT_RETURN_IF_ERROR(GetParameters(past_shape, logits_shape, false));
82+
83+
num_layers = num_subgraph_outputs / 2;
84+
} else {
85+
ORT_RETURN_IF(num_subgraph_outputs < 6 || (num_subgraph_outputs - first_present_output_index_) % 4 != 0,
86+
"number of outputs expected to be 2 + 4 * layers, got:", num_subgraph_outputs);
87+
88+
ORT_RETURN_IF(subgraph_inputs[2]->Name() != "decoder_input_ids",
89+
"encoder subgraph input 2 shall be named as decoder_input_ids, got: ", subgraph_inputs[2]->Name());
90+
ORT_RETURN_IF(subgraph_inputs[2]->TypeAsProto()->tensor_type().elem_type() != int32_type,
91+
"encoder subgraph input 2 (decoder_input_ids) shall have int32 type");
92+
93+
ORT_RETURN_IF(subgraph_outputs[0]->Name() != "logits",
94+
"encoder subgraph output 0 shall be named as logits, got: ", subgraph_outputs[0]->Name());
95+
ORT_RETURN_IF(subgraph_outputs[1]->Name() != "encoder_hidden_states",
96+
"encoder subgraph output 1 shall be named encoder_hidden_states, got: ", subgraph_outputs[1]->Name());
97+
ORT_RETURN_IF(subgraph_outputs[2]->Name() != "present_key_self_0",
98+
"encoder subgraph output 2 shall be named as present_key_self_0, got: ", subgraph_outputs[2]->Name());
99+
ORT_RETURN_IF(subgraph_outputs[3]->Name() != "present_value_self_0",
100+
"encoder subgraph output 3 shall be named as present_value_self_0, got: ", subgraph_outputs[3]->Name());
101+
102+
// Deduce num_heads, head_size and vocab_size from shape of graph outputs
103+
const ONNX_NAMESPACE::TensorShapeProto* past_shape = subgraph_outputs[2]->Shape();
104+
const ONNX_NAMESPACE::TensorShapeProto* logits_shape = subgraph_outputs[0]->Shape();
105+
ORT_RETURN_IF_ERROR(GetParameters(past_shape, logits_shape, false));
106+
107+
num_layers = (num_subgraph_outputs - first_present_output_index_) / 4;
108+
}
82109

83110
auto output_type = subgraph_outputs[0]->TypeAsProto()->tensor_type().elem_type();
84111
ORT_RETURN_IF(output_type != float32_type && output_type != float16_type,
85112
"encoder subgraph output 0 (logits) shall be float or float16 data type");
86113

87114
for (int i = 1; i < num_subgraph_outputs; i++) {
88115
ORT_RETURN_IF(subgraph_outputs[i]->TypeAsProto()->tensor_type().elem_type() != output_type,
89-
"encoder subgraph outputs 1, 2, ... shall have same data type");
116+
"encoder subgraph outputs shall have same data type");
90117
}
91118

92119
is_output_float16_ = (output_type == float16_type);
@@ -120,7 +147,6 @@ Status T5EncoderSubgraph::CreateInitialFeeds(
120147
}
121148
ORT_RETURN_IF(cpu_allocator == nullptr, "cpu_allocator shouldn't be nullptr");
122149

123-
// TODO(tianleiwu): expand the outputs instead of inputs to save computation.
124150
OrtValue encoder_input_ids;
125151
OrtValue encoder_attention_mask;
126152
ORT_RETURN_IF_ERROR(create_encoder_inputs_func(&original_encoder_input_ids,
@@ -136,9 +162,10 @@ Status T5EncoderSubgraph::CreateInitialFeeds(
136162
AllocatorPtr default_allocator = session_state_->GetAllocator(provider->GetOrtDeviceByMemType(OrtMemTypeDefault));
137163
AllocatorPtr pinned_allocator = session_state_->GetAllocator(provider->GetOrtDeviceByMemType(OrtMemTypeCPU));
138164
const OrtMemoryInfo& location = default_allocator->Info();
165+
139166
ORT_RETURN_IF_ERROR(add_to_feeds_func(
140167
ort_stream,
141-
{encoder_input_ids, encoder_attention_mask, decoder_input_ids},
168+
num_subgraph_inputs == 2 ? std::initializer_list<OrtValue>{encoder_input_ids, encoder_attention_mask} : std::initializer_list<OrtValue>{encoder_input_ids, encoder_attention_mask, decoder_input_ids},
142169
feeds,
143170
buffer,
144171
default_allocator,

‎onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.h

Lines changed: 7 additions & 1 deletion
@@ -16,7 +16,7 @@ class T5EncoderSubgraph : public Subgraph {
1616
const onnxruntime::Node& node_in,
1717
const std::string& attribute_name,
1818
const GraphViewer& subgraph_in) : Subgraph(node_in, attribute_name, subgraph_in) {
19-
first_present_output_index_ = 2;
19+
first_present_output_index_ = HasLogitsOutput() ? 2 : 0;
2020
}
2121

2222
// Create inputs for first inference of subgraph.
@@ -36,9 +36,15 @@ class T5EncoderSubgraph : public Subgraph {
3636
Status Validate(const std::vector<const NodeArg*>& subgraph_inputs,
3737
const std::vector<const NodeArg*>& subgraph_outputs) override;
3838

39+
#ifdef DEBUG_GENERATION
3940
int GetFirstPresentOutputIndex() const {
4041
return first_present_output_index_;
4142
}
43+
#endif
44+
45+
bool HasLogitsOutput() const {
46+
return num_subgraph_inputs != 2;
47+
}
4248

4349
protected:
4450
int first_present_output_index_;

‎onnxruntime/python/tools/transformers/convert_generation.py

Lines changed: 702 additions & 231 deletions
Large diffs are not rendered by default.

‎onnxruntime/python/tools/transformers/models/t5/convert_to_onnx.py

Lines changed: 43 additions & 18 deletions
@@ -10,7 +10,12 @@
1010
import os
1111

1212
import torch
13-
from benchmark_helper import Precision, create_onnxruntime_session, prepare_environment, setup_logger
13+
from benchmark_helper import (
14+
Precision,
15+
create_onnxruntime_session,
16+
prepare_environment,
17+
setup_logger,
18+
)
1419
from t5_helper import PRETRAINED_MT5_MODELS, PRETRAINED_T5_MODELS, T5Helper
1520

1621
logger = logging.getLogger("")
@@ -26,7 +31,8 @@ def parse_arguments():
2631
required=False,
2732
default=PRETRAINED_T5_MODELS[0],
2833
type=str,
29-
help="Model path, or pretrained model name in the list: " + ", ".join(pretrained_models),
34+
help="Model path, or pretrained model name in the list: "
35+
+ ", ".join(pretrained_models),
3036
)
3137

3238
parser.add_argument(
@@ -63,7 +69,9 @@ def parse_arguments():
6369
)
6470
parser.set_defaults(optimize_onnx=False)
6571

66-
parser.add_argument("--use_gpu", required=False, action="store_true", help="use GPU for inference")
72+
parser.add_argument(
73+
"--use_gpu", required=False, action="store_true", help="use GPU for inference"
74+
)
6775
parser.set_defaults(use_gpu=False)
6876

6977
parser.add_argument(
@@ -79,7 +87,9 @@ def parse_arguments():
7987
parser.add_argument("--verbose", required=False, action="store_true")
8088
parser.set_defaults(verbose=False)
8189

82-
parser.add_argument("-e", "--use_external_data_format", required=False, action="store_true")
90+
parser.add_argument(
91+
"-e", "--use_external_data_format", required=False, action="store_true"
92+
)
8393
parser.set_defaults(use_external_data_format=False)
8494

8595
parser.add_argument(
@@ -108,14 +118,6 @@ def parse_arguments():
108118
)
109119
parser.set_defaults(disable_auto_mixed_precision=False)
110120

111-
parser.add_argument(
112-
"--separate_encoder_and_decoder_init",
113-
required=False,
114-
action="store_true",
115-
help="Do not merge encode and decoder init. Output 3 instead of 2 onnx models.",
116-
)
117-
parser.set_defaults(separate_encoder_and_decoder_init=False)
118-
119121
parser.add_argument(
120122
"--use_int64_inputs",
121123
required=False,
@@ -131,6 +133,14 @@ def parse_arguments():
131133
help="filepath to load pre-trained model with custom state dictionary (e.g. pytorch_model.bin)",
132134
)
133135

136+
parser.add_argument(
137+
"--encode_decoder_init",
138+
required=False,
139+
action="store_true",
140+
help="Combine encoder and decoder kv cache initialization into one model.",
141+
)
142+
parser.set_defaults(encode_decoder_init=False)
143+
134144
args = parser.parse_args()
135145

136146
return args
@@ -146,17 +156,22 @@ def export_onnx_models(
146156
precision,
147157
verbose,
148158
use_decoder_start_token: bool = False,
149-
merge_encoder_and_decoder_init: bool = True,
150159
overwrite: bool = False,
151160
disable_auto_mixed_precision: bool = False,
152161
use_int32_inputs: bool = True,
153162
model_type: str = "t5",
154163
state_dict_path: str = "",
164+
encode_decoder_init: bool = False,
155165
):
156166
device = torch.device("cuda:0" if use_gpu else "cpu")
157167

158168
models = T5Helper.load_model(
159-
model_name_or_path, cache_dir, device, merge_encoder_and_decoder_init, model_type, state_dict_path
169+
model_name_or_path,
170+
cache_dir,
171+
device,
172+
model_type,
173+
state_dict_path,
174+
encode_decoder_init=encode_decoder_init,
160175
)
161176
config = models["decoder"].config
162177

@@ -220,11 +235,17 @@ def export_onnx_models(
220235
ort_session = create_onnxruntime_session(
221236
output_path,
222237
use_gpu=use_gpu,
223-
provider=["CUDAExecutionProvider", "CPUExecutionProvider"] if use_gpu else ["CPUExecutionProvider"],
238+
provider=(
239+
["CUDAExecutionProvider", "CPUExecutionProvider"]
240+
if use_gpu
241+
else ["CPUExecutionProvider"]
242+
),
224243
)
225244

226245
with torch.no_grad():
227-
max_diff = T5Helper.verify_onnx(model, ort_session, device, use_int32_inputs)
246+
max_diff = T5Helper.verify_onnx(
247+
model, ort_session, device, use_int32_inputs
248+
)
228249
logger.info(f"PyTorch and OnnxRuntime results max difference = {max_diff}")
229250
if max_diff > 1e-4:
230251
logger.warning("PyTorch and OnnxRuntime results are NOT close")
@@ -242,7 +263,11 @@ def main():
242263
logger.info(f"Arguments:{args}")
243264

244265
cache_dir = args.cache_dir
245-
output_dir = args.output if not args.output.endswith(".onnx") else os.path.dirname(args.output)
266+
output_dir = (
267+
args.output
268+
if not args.output.endswith(".onnx")
269+
else os.path.dirname(args.output)
270+
)
246271
prepare_environment(cache_dir, output_dir, args.use_gpu)
247272

248273
if args.precision != Precision.FLOAT32:
@@ -264,11 +289,11 @@ def main():
264289
args.precision,
265290
args.verbose,
266291
args.use_decoder_start_token,
267-
not args.separate_encoder_and_decoder_init,
268292
args.overwrite,
269293
args.disable_auto_mixed_precision,
270294
not args.use_int64_inputs,
271295
args.model_type,
296+
encode_decoder_init=args.encode_decoder_init,
272297
)
273298

274299
logger.info(f"Done! Outputs: {output_paths}")
onnxruntime/python/tools/transformers/models/t5/t5_encoder.py

Lines changed: 11 additions & 109 deletions
@@ -1,24 +1,14 @@
11
# -------------------------------------------------------------------------
2-
# Copyright (c) Microsoft Corporation. All rights reserved.
3-
# Licensed under the MIT License. See License.txt in the project root for
4-
# license information.
5-
# --------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License.
4+
# -------------------------------------------------------------------------
65

76
import logging
8-
import os
97
import random
10-
import tempfile
11-
from pathlib import Path
128

13-
import numpy
14-
import onnx
159
import torch
16-
from onnx_model import OnnxModel
17-
from torch_onnx_export_helper import torch_onnx_export
1810
from transformers import MT5Config, T5Config
1911

20-
from onnxruntime import InferenceSession
21-
2212
logger = logging.getLogger(__name__)
2313

2414

@@ -41,7 +31,11 @@ def __init__(self, input_ids, attention_mask):
4131

4232
@staticmethod
4333
def create_dummy(
44-
batch_size: int, sequence_length: int, vocab_size: int, device: torch.device, use_int32_inputs: bool = False
34+
batch_size: int,
35+
sequence_length: int,
36+
vocab_size: int,
37+
device: torch.device,
38+
use_int32_inputs: bool = False,
4539
): # -> T5EncoderInputs
4640
"""Create dummy inputs for T5 encoder.
4741
@@ -64,7 +58,9 @@ def create_dummy(
6458
device=device,
6559
)
6660

67-
attention_mask = torch.ones([batch_size, sequence_length], dtype=dtype, device=device)
61+
attention_mask = torch.ones(
62+
[batch_size, sequence_length], dtype=dtype, device=device
63+
)
6864
if sequence_length >= 2:
6965
for i in range(batch_size):
7066
padding_position = random.randint(0, sequence_length - 1)
@@ -74,97 +70,3 @@ def create_dummy(
7470
def to_list(self) -> list:
7571
input_list = [v for v in [self.input_ids, self.attention_mask] if v is not None]
7672
return input_list
77-
78-
79-
class T5EncoderHelper:
80-
@staticmethod
81-
def export_onnx(
82-
encoder: T5Encoder,
83-
device: torch.device,
84-
onnx_model_path: str,
85-
verbose: bool = True,
86-
use_external_data_format: bool = False,
87-
use_int32_inputs: bool = False,
88-
):
89-
"""Export encoder to ONNX
90-
91-
Args:
92-
encoder (T5Encoder): encoder object
93-
device (torch.device): device of encoder object
94-
onnx_model_path (str): onnx path
95-
verbose (bool, optional): print verbose information. Defaults to True.
96-
use_external_data_format (bool, optional): use external data format or not. Defaults to False.
97-
"""
98-
config = encoder.config
99-
encoder_inputs = T5EncoderInputs.create_dummy(
100-
batch_size=2,
101-
sequence_length=4,
102-
vocab_size=config.vocab_size,
103-
device=device,
104-
use_int32_inputs=use_int32_inputs,
105-
)
106-
107-
Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
108-
109-
with tempfile.TemporaryDirectory() as tmp_dir_name:
110-
temp_onnx_model_path = os.path.join(tmp_dir_name, "encoder.onnx")
111-
Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
112-
torch_onnx_export(
113-
encoder,
114-
args=tuple(encoder_inputs.to_list()),
115-
f=temp_onnx_model_path if use_external_data_format else onnx_model_path,
116-
export_params=True,
117-
input_names=["input_ids", "attention_mask"],
118-
output_names=["hidden_states"],
119-
dynamic_axes={
120-
"input_ids": {0: "batch_size", 1: "sequence_length"},
121-
"attention_mask": {0: "batch_size", 1: "sequence_length"},
122-
"hidden_states": {0: "batch_size", 1: "sequence_length"},
123-
},
124-
opset_version=12,
125-
do_constant_folding=True,
126-
use_external_data_format=use_external_data_format,
127-
verbose=verbose,
128-
)
129-
130-
if use_external_data_format:
131-
model = onnx.load_model(temp_onnx_model_path, load_external_data=True)
132-
OnnxModel.save(
133-
model,
134-
onnx_model_path,
135-
save_as_external_data=True,
136-
all_tensors_to_one_file=True,
137-
)
138-
139-
@staticmethod
140-
def onnxruntime_inference(ort_session, inputs: T5EncoderInputs):
141-
"""Run inference of ONNX model."""
142-
ort_inputs = {
143-
"input_ids": numpy.ascontiguousarray(inputs.input_ids.cpu().numpy()),
144-
"attention_mask": numpy.ascontiguousarray(inputs.attention_mask.cpu().numpy()),
145-
}
146-
147-
return ort_session.run(None, ort_inputs)
148-
149-
@staticmethod
150-
def verify_onnx(
151-
model: T5Encoder, ort_session: InferenceSession, device: torch.device, use_int32_inputs: bool = False
152-
):
153-
"""Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
154-
inputs = T5EncoderInputs.create_dummy(
155-
batch_size=4,
156-
sequence_length=11,
157-
vocab_size=model.config.vocab_size,
158-
device=device,
159-
use_int32_inputs=use_int32_inputs,
160-
)
161-
input_list = inputs.to_list()
162-
torch_outputs = model(*input_list)
163-
164-
ort_outputs = T5EncoderHelper.onnxruntime_inference(ort_session, inputs)
165-
166-
max_diff = numpy.amax(numpy.abs(torch_outputs.cpu().numpy() - ort_outputs[0]))
167-
168-
logger.info(f"max_diff={max_diff}")
169-
170-
return max_diff

‎onnxruntime/python/tools/transformers/models/t5/t5_encoder_decoder_init.py

Lines changed: 141 additions & 34 deletions
@@ -1,8 +1,7 @@
11
# -------------------------------------------------------------------------
2-
# Copyright (c) Microsoft Corporation. All rights reserved.
3-
# Licensed under the MIT License. See License.txt in the project root for
4-
# license information.
5-
# --------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License.
4+
# -------------------------------------------------------------------------
65

76
import logging
87
import os
@@ -34,27 +33,42 @@ def __init__(
3433
lm_head: torch.nn.Module,
3534
config: T5Config | MT5Config,
3635
decoder_start_token_id: int | None = None,
36+
output_cross_only: bool = False,
3737
):
3838
super().__init__()
3939
self.config = config
4040
self.t5_encoder = T5Encoder(encoder, config)
41-
self.t5_decoder_init = T5DecoderInit(decoder, lm_head, config, decoder_start_token_id)
41+
self.t5_decoder_init = T5DecoderInit(
42+
decoder, lm_head, config, decoder_start_token_id
43+
)
44+
self.output_cross_only = output_cross_only
4245

4346
def forward(
4447
self,
4548
encoder_input_ids: torch.Tensor,
4649
encoder_attention_mask: torch.Tensor,
4750
decoder_input_ids: torch.Tensor = None,
4851
):
49-
encoder_hidden_states: torch.FloatTensor = self.t5_encoder(encoder_input_ids, encoder_attention_mask)
50-
lm_logits, past_self, past_cross = self.t5_decoder_init(
51-
decoder_input_ids, encoder_attention_mask, encoder_hidden_states
52+
encoder_hidden_states: torch.FloatTensor = self.t5_encoder(
53+
encoder_input_ids, encoder_attention_mask
5254
)
53-
return lm_logits, encoder_hidden_states, past_self, past_cross
55+
56+
if self.output_cross_only:
57+
lm_logits, past_self, past_cross = self.t5_decoder_init(
58+
decoder_input_ids, encoder_attention_mask, encoder_hidden_states
59+
)
60+
return past_cross
61+
else:
62+
lm_logits, past_self, past_cross = self.t5_decoder_init(
63+
decoder_input_ids, encoder_attention_mask, encoder_hidden_states
64+
)
65+
return lm_logits, encoder_hidden_states, past_self, past_cross
5466

5567

5668
class T5EncoderDecoderInitInputs:
57-
def __init__(self, encoder_input_ids, encoder_attention_mask, decoder_input_ids=None):
69+
def __init__(
70+
self, encoder_input_ids, encoder_attention_mask, decoder_input_ids=None
71+
):
5872
self.encoder_input_ids: torch.LongTensor = encoder_input_ids
5973
self.encoder_attention_mask: torch.LongTensor = encoder_attention_mask
6074
self.decoder_input_ids: torch.LongTensor = decoder_input_ids
@@ -78,9 +92,14 @@ def create_dummy(
7892
decoder_input_ids = None
7993
if use_decoder_input_ids:
8094
dtype = torch.int32 if use_int32_inputs else torch.int64
81-
decoder_input_ids = torch.ones((batch_size, 1), dtype=dtype, device=device) * config.decoder_start_token_id
95+
decoder_input_ids = (
96+
torch.ones((batch_size, 1), dtype=dtype, device=device)
97+
* config.decoder_start_token_id
98+
)
8299

83-
return T5EncoderDecoderInitInputs(encoder_inputs.input_ids, encoder_inputs.attention_mask, decoder_input_ids)
100+
return T5EncoderDecoderInitInputs(
101+
encoder_inputs.input_ids, encoder_inputs.attention_mask, decoder_input_ids
102+
)
84103

85104
def to_list(self) -> list:
86105
input_list = [self.encoder_input_ids, self.encoder_attention_mask]
@@ -108,9 +127,14 @@ def export_onnx(
108127
onnx_model_path (str): onnx path
109128
verbose (bool, optional): print verbose information. Defaults to True.
110129
use_external_data_format (bool, optional): use external data format or not. Defaults to False.
130+
use_int32_inputs (bool, optional): use int32 instead of int64 for integer inputs. Defaults to False.
111131
"""
112132
assert isinstance(model, T5EncoderDecoderInit)
113133

134+
# Do not exclude the decoder during torch onnx export so that the cross-attention outputs show up.
135+
output_cross_only = model.output_cross_only
136+
model.output_cross_only = False
137+
114138
inputs = T5EncoderDecoderInitInputs.create_dummy(
115139
model.config,
116140
batch_size=2,
@@ -121,7 +145,9 @@ def export_onnx(
121145
)
122146
input_list = inputs.to_list()
123147

124-
present_names = PastKeyValuesHelper.get_past_names(model.config.num_decoder_layers, present=True)
148+
present_names = PastKeyValuesHelper.get_past_names(
149+
model.config.num_decoder_layers, present=True
150+
)
125151

126152
output_names = ["logits", "encoder_hidden_states", *present_names]
127153

@@ -185,7 +211,9 @@ def export_onnx(
185211
}
186212

187213
with tempfile.TemporaryDirectory() as tmp_dir_name:
188-
temp_onnx_model_path = os.path.join(tmp_dir_name, "encoder_decoder_init.onnx")
214+
temp_onnx_model_path = os.path.join(
215+
tmp_dir_name, "encoder_decoder_init.onnx"
216+
)
189217
Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
190218
torch_onnx_export(
191219
model,
@@ -201,6 +229,9 @@ def export_onnx(
201229
verbose=verbose,
202230
)
203231

232+
# Restore output_cross_only setting.
233+
model.output_cross_only = output_cross_only
234+
204235
# Workaround as mentioned earlier: change numeric dim_param to dim_value
205236
model = onnx.load(temp_onnx_model_path)
206237
for tensor in model.graph.output:
@@ -215,6 +246,54 @@ def export_onnx(
215246
dim_proto.Clear()
216247
dim_proto.dim_value = dim_value
217248

249+
if output_cross_only:
250+
# Rewrite onnx graph to only keep present_[key|value]_cross_* outputs.
251+
onnx_model = OnnxModel(model)
252+
output_name_to_node = onnx_model.output_name_to_node()
253+
254+
for output in model.graph.output:
255+
if "cross" in output.name:
256+
assert output.name in output_name_to_node
257+
258+
transpose_node = output_name_to_node[output.name]
259+
assert transpose_node and transpose_node.op_type == "Transpose"
260+
261+
permutation = OnnxModel.get_node_attribute(
262+
transpose_node, "perm"
263+
)
264+
assert isinstance(permutation, list)
265+
assert permutation == [0, 2, 1, 3]
266+
267+
matched_nodes = onnx_model.match_parent_path(
268+
transpose_node,
269+
["Reshape", "MatMul"],
270+
[0, 0],
271+
output_name_to_node,
272+
)
273+
assert matched_nodes is not None
274+
275+
reshape_node, matmul_node = matched_nodes
276+
assert "encoder_hidden_states" in matmul_node.input
277+
278+
if not onnx_model.get_initializer("cross_reshape_shape"):
279+
shape_tensor = onnx.helper.make_tensor(
280+
name="cross_reshape_shape",
281+
data_type=onnx.TensorProto.INT64,
282+
dims=[4],
283+
vals=[0, 0, int(num_heads), int(head_size)],
284+
raw=False,
285+
)
286+
onnx_model.add_initializer(shape_tensor)
287+
288+
reshape_node.input[1] = "cross_reshape_shape"
289+
290+
cross_outputs = [
291+
output.name
292+
for output in model.graph.output
293+
if "cross" in output.name
294+
]
295+
onnx_model.prune_graph(cross_outputs, allow_remove_graph_inputs=True)
296+
218297
OnnxModel.save(
219298
model,
220299
onnx_model_path,
@@ -228,11 +307,17 @@ def onnxruntime_inference(ort_session, inputs: T5EncoderDecoderInitInputs):
228307
logger.debug("start onnxruntime_inference")
229308

230309
ort_inputs = {
231-
"encoder_input_ids": numpy.ascontiguousarray(inputs.encoder_input_ids.cpu().numpy()),
232-
"encoder_attention_mask": numpy.ascontiguousarray(inputs.encoder_attention_mask.cpu().numpy()),
310+
"encoder_input_ids": numpy.ascontiguousarray(
311+
inputs.encoder_input_ids.cpu().numpy()
312+
),
313+
"encoder_attention_mask": numpy.ascontiguousarray(
314+
inputs.encoder_attention_mask.cpu().numpy()
315+
),
233316
}
234317
if inputs.decoder_input_ids is not None:
235-
ort_inputs["decoder_input_ids"] = numpy.ascontiguousarray(inputs.decoder_input_ids.cpu().numpy())
318+
ort_inputs["decoder_input_ids"] = numpy.ascontiguousarray(
319+
inputs.decoder_input_ids.cpu().numpy()
320+
)
236321

237322
ort_outputs = ort_session.run(None, ort_inputs)
238323
return ort_outputs
@@ -261,35 +346,57 @@ def verify_onnx(
261346
use_int32_inputs=use_int32_inputs,
262347
)
263348

264-
ort_outputs = T5EncoderDecoderInitHelper.onnxruntime_inference(ort_session, inputs)
349+
ort_outputs = T5EncoderDecoderInitHelper.onnxruntime_inference(
350+
ort_session, inputs
351+
)
265352

266353
# Run inference of PyTorch model
267354
input_list = inputs.to_list()
268355
torch_outputs = model(*input_list)
269356

270357
num_decoder_layers = model.config.num_decoder_layers
271358

272-
assert torch_outputs[0].cpu().numpy().shape == ort_outputs[0].shape
273-
max_diff = numpy.amax(numpy.abs(torch_outputs[0].cpu().numpy() - ort_outputs[0]))
274-
logger.debug(f"logits max_diff={max_diff}")
275-
max_diff_all = max_diff
276-
277-
assert torch_outputs[1].cpu().numpy().shape == ort_outputs[1].shape
278-
max_diff = numpy.amax(numpy.abs(torch_outputs[1].cpu().numpy() - ort_outputs[1]))
279-
logger.debug(f"encoder_hidden_states max_diff={max_diff}")
280-
max_diff_all = max(max_diff_all, max_diff)
281-
282-
for i in range(2 * num_decoder_layers):
283-
max_diff = numpy.amax(numpy.abs(torch_outputs[2][i].cpu().numpy() - ort_outputs[2 + i]))
284-
logger.debug(f"self attention past state {i} max_diff={max_diff}")
359+
if not model.output_cross_only:
360+
assert torch_outputs[0].cpu().numpy().shape == ort_outputs[0].shape
361+
max_diff = numpy.amax(
362+
numpy.abs(torch_outputs[0].cpu().numpy() - ort_outputs[0])
363+
)
364+
logger.debug(f"logits max_diff={max_diff}")
365+
max_diff_all = max_diff
285366

286-
for i in range(2 * num_decoder_layers):
367+
assert torch_outputs[1].cpu().numpy().shape == ort_outputs[1].shape
287368
max_diff = numpy.amax(
288-
numpy.abs(torch_outputs[3][i].cpu().numpy() - ort_outputs[2 + 2 * num_decoder_layers + i])
369+
numpy.abs(torch_outputs[1].cpu().numpy() - ort_outputs[1])
289370
)
290-
logger.debug(f"cross attention past state {i} max_diff={max_diff}")
371+
logger.debug(f"encoder_hidden_states max_diff={max_diff}")
291372
max_diff_all = max(max_diff_all, max_diff)
292373

374+
for i in range(2 * num_decoder_layers):
375+
max_diff = numpy.amax(
376+
numpy.abs(
377+
torch_outputs[2][i].cpu().numpy() - ort_outputs[2 + i]
378+
)
379+
)
380+
logger.debug(f"self attention past state {i} max_diff={max_diff}")
381+
382+
for i in range(2 * num_decoder_layers):
383+
max_diff = numpy.amax(
384+
numpy.abs(
385+
torch_outputs[3][i].cpu().numpy()
386+
- ort_outputs[2 + 2 * num_decoder_layers + i]
387+
)
388+
)
389+
logger.debug(f"cross attention past state {i} max_diff={max_diff}")
390+
max_diff_all = max(max_diff_all, max_diff)
391+
else:
392+
max_diff_all = -float("inf")
393+
for i in range(2 * num_decoder_layers):
394+
max_diff = numpy.amax(
395+
numpy.abs(torch_outputs[i].cpu().numpy() - ort_outputs[i])
396+
)
397+
logger.debug(f"cross attention past state {i} max_diff={max_diff}")
398+
max_diff_all = max(max_diff_all, max_diff)
399+
293400
test_cases_max_diff.append(max_diff_all)
294401
logger.info(
295402
f"batch_size={batch_size} encode_sequence_length={encode_sequence_length}, max_diff={max_diff_all}"

‎onnxruntime/python/tools/transformers/models/t5/t5_helper.py

Lines changed: 47 additions & 50 deletions
@@ -1,8 +1,7 @@
11
# -------------------------------------------------------------------------
2-
# Copyright (c) Microsoft Corporation. All rights reserved.
3-
# Licensed under the MIT License. See License.txt in the project root for
4-
# license information.
5-
# --------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License.
4+
# -------------------------------------------------------------------------
65

76
import logging
87
import os
@@ -12,8 +11,7 @@
1211
from float16 import float_to_float16_max_diff
1312
from onnx_model import OnnxModel
1413
from optimizer import optimize_model
15-
from t5_decoder import T5Decoder, T5DecoderHelper, T5DecoderInit
16-
from t5_encoder import T5Encoder, T5EncoderHelper
14+
from t5_decoder import T5Decoder, T5DecoderHelper
1715
from t5_encoder_decoder_init import T5EncoderDecoderInit, T5EncoderDecoderInitHelper
1816
from transformers import MT5ForConditionalGeneration, T5ForConditionalGeneration
1917

@@ -22,7 +20,13 @@
2220
logger = logging.getLogger(__name__)
2321

2422
PRETRAINED_T5_MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]
25-
PRETRAINED_MT5_MODELS = ["google/mt5-small", "google/mt5-base", "google/mt5-large", "google/mt5-xl", "google/mt5-xxl"]
23+
PRETRAINED_MT5_MODELS = [
24+
"google/mt5-small",
25+
"google/mt5-base",
26+
"google/mt5-large",
27+
"google/mt5-xl",
28+
"google/mt5-xxl",
29+
]
2630

2731

2832
class T5Helper:
@@ -60,25 +64,30 @@ def load_model(
6064
model_name_or_path: str,
6165
cache_dir: str,
6266
device: torch.device,
63-
merge_encoder_and_decoder_init: bool = True,
6467
model_type: str = "t5",
6568
state_dict_path: str = "",
69+
encode_decoder_init: bool = False,
6670
) -> dict[str, torch.nn.Module]:
6771
"""Load model given a pretrained name or path, then build models for ONNX conversion.
6872
6973
Args:
7074
model_name_or_path (str): pretrained model name or path
7175
cache_dir (str): cache directory
7276
device (torch.device): device to run the model
73-
merge_encoder_and_decoder_init (bool, optional): Whether merge encoder and decoder initialization into one ONNX model. Defaults to True.
74-
is_mt5 (bool, optional): whether the model is MT5 instead of T5
77+
model_type (str, optional): model type "t5" or "mt5"
78+
state_dict_path(str, optional): state dictionary path
79+
encode_decoder_init (bool, optional): combine encoder and decoder kv cache initialization into one model.
7580
Returns:
7681
Dict[str, torch.nn.Module]: mapping from name to modules for ONNX conversion.
7782
"""
7883
if model_type == "t5":
79-
model = T5ForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir)
84+
model = T5ForConditionalGeneration.from_pretrained(
85+
model_name_or_path, cache_dir=cache_dir
86+
)
8087
elif model_type == "mt5":
81-
model = MT5ForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir)
88+
model = MT5ForConditionalGeneration.from_pretrained(
89+
model_name_or_path, cache_dir=cache_dir
90+
)
8291
else:
8392
raise ValueError("only support mode_type=t5 or mt5")
8493

@@ -88,46 +97,29 @@ def load_model(
8897
decoder = T5Decoder(model.decoder, model.lm_head, model.config)
8998
decoder.eval().to(device)
9099

91-
if merge_encoder_and_decoder_init:
92-
encoder_decoder_init = T5EncoderDecoderInit(
93-
model.encoder,
94-
model.decoder,
95-
model.lm_head,
96-
model.config,
97-
decoder_start_token_id=None,
98-
)
99-
return {"encoder_decoder_init": encoder_decoder_init, "decoder": decoder}
100-
else:
101-
encoder = T5Encoder(model.encoder, model.config)
102-
encoder.eval().to(device)
103-
decoder_init = T5DecoderInit(model.decoder, model.lm_head, model.config)
104-
decoder_init.eval().to(device)
105-
return {
106-
"encoder": encoder,
107-
"decoder": decoder,
108-
"decoder_init": decoder_init,
109-
}
100+
encoder_decoder_init = T5EncoderDecoderInit(
101+
model.encoder,
102+
model.decoder,
103+
model.lm_head,
104+
model.config,
105+
decoder_start_token_id=None,
106+
output_cross_only=not encode_decoder_init,
107+
)
108+
109+
encoder_name = "encoder_decoder_init" if encode_decoder_init else "encoder"
110+
return {encoder_name: encoder_decoder_init, "decoder": decoder}
110111

111112
@staticmethod
112113
def export_onnx(
113-
model: T5Encoder | T5Decoder | T5DecoderInit | T5EncoderDecoderInit,
114+
model: T5Decoder | T5EncoderDecoderInit,
114115
device: torch.device,
115116
onnx_model_path: str,
116117
verbose: bool = True,
117118
use_external_data_format: bool = False,
118119
use_decoder_input_ids: bool = True,
119120
use_int32_inputs: bool = False,
120121
):
121-
if isinstance(model, T5Encoder):
122-
T5EncoderHelper.export_onnx(
123-
model,
124-
device,
125-
onnx_model_path,
126-
verbose,
127-
use_external_data_format,
128-
use_int32_inputs,
129-
)
130-
elif isinstance(model, T5EncoderDecoderInit):
122+
if isinstance(model, T5EncoderDecoderInit):
131123
T5EncoderDecoderInitHelper.export_onnx(
132124
model,
133125
device,
@@ -191,10 +183,14 @@ def auto_mixed_precision(
191183
# when the max difference of value after converting float to float16 is lower than a threshold (1e-6),
192184
# we can deduce that the weights are stored in float16 precision.
193185
max_diff = float_to_float16_max_diff(initializer)
194-
logger.debug(f"max diff of converting weights in last MatMul node {node.name}: {max_diff}")
186+
logger.debug(
187+
f"max diff of converting weights in last MatMul node {node.name}: {max_diff}"
188+
)
195189
is_weight_fp16_precision = max_diff < 1e-6
196190
else:
197-
logger.warning(f"Failed to find MatMul node for logits. Found {node.op_type} of node {node.name}")
191+
logger.warning(
192+
f"Failed to find MatMul node for logits. Found {node.op_type} of node {node.name}"
193+
)
198194

199195
keep_io_types = []
200196
node_block_list = []
@@ -252,20 +248,21 @@ def optimize_onnx(
252248
else:
253249
m.convert_model_float32_to_float16(cast_input_output=False)
254250

255-
m.save_model_to_file(optimized_model_path, use_external_data_format, all_tensors_to_one_file=True)
251+
m.save_model_to_file(
252+
optimized_model_path, use_external_data_format, all_tensors_to_one_file=True
253+
)
256254

257255
@staticmethod
258256
def verify_onnx(
259-
model: T5Encoder | T5Decoder | T5DecoderInit | T5EncoderDecoderInit,
257+
model: T5Decoder | T5EncoderDecoderInit,
260258
ort_session: InferenceSession,
261259
device: torch.device,
262260
use_int32_inputs: bool,
263261
):
264262
"""Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
265-
if isinstance(model, T5Encoder):
266-
return T5EncoderHelper.verify_onnx(model, ort_session, device, use_int32_inputs)
267-
268263
if isinstance(model, T5EncoderDecoderInit):
269-
return T5EncoderDecoderInitHelper.verify_onnx(model, ort_session, device, use_int32_inputs)
264+
return T5EncoderDecoderInitHelper.verify_onnx(
265+
model, ort_session, device, use_int32_inputs
266+
)
270267

271268
return T5DecoderHelper.verify_onnx(model, ort_session, device, use_int32_inputs)
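
A short usage sketch of the updated load_model API (the cache path is arbitrary, the import assumes the models/t5 scripts directory is on sys.path, and running this downloads t5-small from Hugging Face):

import torch
from t5_helper import T5Helper

device = torch.device("cpu")

# Default: new format, the exported encoder produces only cross-attention KV caches.
models = T5Helper.load_model("t5-small", "./cache", device)
assert set(models) == {"encoder", "decoder"}

# Legacy behavior: combined encoder + decoder-init graph that also outputs logits.
legacy = T5Helper.load_model("t5-small", "./cache", device, encode_decoder_init=True)
assert set(legacy) == {"encoder_decoder_init", "decoder"}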
