
Commit 5ce3639
Committed Mar 10, 2025

new design of t5 onnx

1 parent: 49328fe

12 files changed (+1145 / -580 lines)
 

onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
Lines changed: 11 additions & 5 deletions
@@ -139,13 +139,19 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state,
       ORT_RETURN_IF_ERROR(t5_encoder_subgraph_->Setup(session_state, subgraph_session_state));
       encoder_feeds_fetches_manager_ = t5_encoder_subgraph_->GetFeedsFetchesManager();

-      if (parameters_->decoder_start_token_id < 0) {
-        ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 2,
-                      "Encoder subgraph shall have 2 inputs when decoder_start_token_id attribute is empty");
+      if (!t5_encoder_subgraph_->HasLogitsOutput()) {
+        // New format requires start token id.
+        ORT_ENFORCE(parameters_->decoder_start_token_id >= 0);
       } else {
-        ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 3,
-                      "Encoder subgraph shall have 3 inputs when decoder_start_token_id attribute is available");
+        if (parameters_->decoder_start_token_id < 0) {
+          ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 2,
+                        "Encoder subgraph shall have 2 inputs when decoder_start_token_id attribute is empty");
+        } else {
+          ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 3,
+                        "Encoder subgraph shall have 3 inputs when decoder_start_token_id attribute is available");
+        }
       }
+
     } else if (attribute_name == "decoder") {
       ORT_ENFORCE(t5_decoder_subgraph_ == nullptr,
                   "SetupSubgraphExecutionInfo should only be called once for each subgraph.");

onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h
Lines changed: 16 additions & 4 deletions
@@ -51,7 +51,13 @@ class BeamSearchT5 : public BeamSearchBase<T> {
         expand_buffer_int32_func_(expand_buffer_int32_func),
         expand_buffer_float_func_(expand_buffer_float_func),
         expand_buffer_float16_func_(expand_buffer_float16_func),
-        create_beam_scorer_func_(create_beam_scorer_func) {}
+        create_beam_scorer_func_(create_beam_scorer_func) {
+    // When the decoder uses encoder_hidden_states, make sure the encoder outputs it.
+    if (decoder_subgraph_.UseEncoderHiddenState()) {
+      ORT_ENFORCE(encoder_subgraph_.subgraph_output_names[1] == "encoder_hidden_states");
+    }
+    ORT_ENFORCE(encoder_subgraph_.num_layers == decoder_subgraph_.num_layers);
+  }

 #ifdef USE_CUDA
   Status InitializeCuda(
@@ -160,7 +166,7 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
                                                       this->create_encoder_inputs_func_,
                                                       this->add_to_feeds_func_,
                                                       buffer,
-                                                      decoder_input_ids,
+                                                      decoder_input_ids,  // not used by the new-format encoder; still initialized here when decoder_start_token_id >= 0
                                                       this->ort_stream_));

 #ifdef DEBUG_NODE_INPUTS_OUTPUTS
@@ -233,15 +239,20 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches

   std::vector<OrtValue> decoder_fetches;

-  if (current_length + 1 < parameters->max_length) {
+  // When the encoder outputs logits (old format), we need to get the next token from those logits.
+  if (current_length + 1 < parameters->max_length && encoder_subgraph_.HasLogitsOutput()) {
     ++iteration_counter;
-    ORT_RETURN_IF_ERROR(this->GenerateNextToken(encoder_fetches[0],
+    const OrtValue& logits = encoder_fetches[0];
+    ORT_RETURN_IF_ERROR(this->GenerateNextToken(logits,
                                                 beam_next_tokens,
                                                 beam_state,
                                                 cpu_state,
                                                 iteration_counter));
     ++current_length;  // Increase sequence length after a new token is generated.
+  }

+  // Generate inputs for the next decoder subgraph call.
+  if (current_length < parameters->max_length) {
     ORT_RETURN_IF_ERROR(decoder_subgraph_.CreateInitialFeeds(this->cpu_allocator_,
                                                              ReinterpretAsSpan<const int32_t>(beam_next_tokens),
                                                              this->implicit_inputs_,
@@ -262,6 +273,7 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
                                                              this->cuda_device_prop_ != nullptr));

     if (decoder_subgraph_.past_present_share_buffer_) {
+      // Configure buffer sharing for the past and present KV cache.
       decoder_fetches.reserve(static_cast<size_t>(decoder_subgraph_.GetFirstPresentOutputIndex()) +
                               2 * static_cast<size_t>(decoder_subgraph_.num_layers));
       decoder_fetches.resize(decoder_subgraph_.GetFirstPresentOutputIndex(), OrtValue());
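Taken together, the Execute changes split first-token handling by format: the old format samples the first new token from the encoder's logits, while the new format skips that step and starts decoding directly from the start token. Below is a hedged, simplified sketch of that control flow; all types, helper names, and values are stand-ins, not the real ORT API.

// Hedged sketch of the first-token control flow in BeamSearchT5<T>::Execute.
#include <iostream>
#include <vector>

struct Logits {};  // placeholder for the encoder's logits tensor

std::vector<int> SampleFromLogits(const Logits&) { return {0}; }  // placeholder sampler

int main() {
  const int max_length = 8;
  int current_length = 1;                 // decoder_input_ids already hold the start token
  const bool encoder_has_logits = false;  // false for the new format, true for the old format

  std::vector<int> beam_next_tokens;
  if (current_length + 1 < max_length && encoder_has_logits) {
    // Old format only: the first new token is sampled from the encoder's logits
    // before the decoder subgraph ever runs.
    beam_next_tokens = SampleFromLogits(Logits{});
    ++current_length;
  }

  if (current_length < max_length) {
    // Both formats: build decoder feeds and enter the decoding loop.
    std::cout << "decoder loop starts at length " << current_length
              << " with " << beam_next_tokens.size() << " pre-sampled token(s)\n";
  }
  return 0;
}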

onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc
Lines changed: 8 additions & 10 deletions
@@ -36,12 +36,9 @@ Subgraph::Subgraph(
   auto& subgraph_inputs = subgraph.GetInputs();
   auto& subgraph_outputs = subgraph.GetOutputs();

-  // inputs: input_ids, position_ids, attention_mask, past_0, past_1, ...
-  // outputs: logits, present_0, present_1, ...
   num_subgraph_inputs = static_cast<int>(subgraph_inputs.size());
   num_subgraph_outputs = static_cast<int>(subgraph_outputs.size());

-  // CheckSubgraph will verify inputs and outputs later.
   subgraph_input_names.reserve(num_subgraph_inputs);
   for (int i = 0; i < num_subgraph_inputs; ++i) {
     subgraph_input_names.push_back(subgraph_inputs[i]->Name());
@@ -68,10 +65,9 @@ Status Subgraph::Setup(const SessionState& session_state,
   InlinedVector<std::string_view> feed_names;
   feed_names.reserve(static_cast<size_t>(num_subgraph_inputs) + static_cast<size_t>(num_implicit_inputs));

-  // Use the first output (logits) to find device location.
+  // Use the first output to find device location.
   const OrtDevice& default_location = utils::FindDeviceForValue(subgraph_session_state, subgraph_output_names[0]);

-  // The position_ids, attention_mask, past_0, ... are created by this operator so the name doesn't matter.
   feed_names.insert(feed_names.end(), subgraph_input_names.begin(), subgraph_input_names.end());

   const auto& subgraph_map = subgraph_session_state.GetOrtValueNameIdxMap();
@@ -174,13 +170,15 @@ Status Subgraph::GetParameters(const ONNX_NAMESPACE::TensorShapeProto* past_shap
   }

   // Logits shape is like (batch_size, seq_len, vocabulary_size)
-  ORT_RETURN_IF(logits_shape->dim_size() != 3,
-                "subgraph logits output is expected to have 3 dimension, got ", logits_shape->dim_size());
+  if (logits_shape != nullptr) {
+    ORT_RETURN_IF(logits_shape->dim_size() != 3,
+                  "subgraph logits output is expected to have 3 dimension, got ", logits_shape->dim_size());

-  ORT_RETURN_IF(!logits_shape->dim(2).has_dim_value() || logits_shape->dim(2).dim_value() <= 0,
-                "subgraph past state dimension 2 shall have a positive value for vocabulary size");
+    ORT_RETURN_IF(!logits_shape->dim(2).has_dim_value() || logits_shape->dim(2).dim_value() <= 0,
+                  "subgraph past state dimension 2 shall have a positive value for vocabulary size");

-  this->vocab_size = static_cast<int>(logits_shape->dim(2).dim_value());
+    this->vocab_size = static_cast<int>(logits_shape->dim(2).dim_value());
+  }

   return Status::OK();
 }
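A small, hedged sketch of the nullable-logits pattern this change introduces: vocab_size is derived only when the subgraph actually has a logits output, so a new-format encoder subgraph can pass a null shape. ShapeProto and DeriveVocabSize below are simplified stand-ins, not the ORT types.

// Stand-in illustration only; the real logic is in Subgraph::GetParameters.
#include <cstdio>
#include <vector>

struct ShapeProto {
  std::vector<long long> dims;
};

int DeriveVocabSize(const ShapeProto* logits_shape) {
  if (logits_shape == nullptr) return 0;           // new format: leave vocab_size unset
  if (logits_shape->dims.size() != 3) return -1;   // expected (batch_size, seq_len, vocab_size)
  return static_cast<int>(logits_shape->dims[2]);
}

int main() {
  ShapeProto logits{{2, 5, 32128}};
  std::printf("old format vocab_size=%d, new format vocab_size=%d\n",
              DeriveVocabSize(&logits), DeriveVocabSize(nullptr));
  return 0;
}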

onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
Lines changed: 83 additions & 73 deletions
@@ -141,14 +141,23 @@ Status T5DecoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_i
 }

 // Create inputs for decoder from the following data sources:
-// encoder feeds: encoder_input_ids, encoder_attention_mask, decoder_input_ids (with start tokens)
-// encoder fetches: logits,
-//                  encoder_hidden_states,
-//                  present_key_self_0, present_value_self_0, ..., present_key_cross_0, present_value_cross_0, ...
-// decoder_feeds: input_ids,
-//                encoder_attention_mask,
-//                encoder_hidden_states,
-//                present_key_self_0, present_value_self_0, ..., present_key_cross_0, present_value_cross_0, ...
+// New format:
+//   encoder feeds: encoder_input_ids, encoder_attention_mask
+//   encoder fetches: present_key_cross_0, present_value_cross_0, ...
+//   decoder_feeds: input_ids, encoder_attention_mask,
+//                  present_key_self_0, present_value_self_0, ...,
+//                  present_key_cross_0, present_value_cross_0, ...
+
+// Old format:
+//   encoder feeds: encoder_input_ids, encoder_attention_mask, decoder_input_ids (with start tokens)
+//   encoder fetches: logits, encoder_hidden_states,
+//                    present_key_self_0, present_value_self_0, ...,
+//                    present_key_cross_0, present_value_cross_0, ...
+//   decoder_feeds: input_ids, encoder_input_ids (optional), encoder_attention_mask, encoder_hidden_states,
+//                  present_key_self_0, present_value_self_0, ...,
+//                  present_key_cross_0, present_value_cross_0, ...
+//                  past_seq_len (optional), num_beams (optional), cache_indirection (optional)
+
 Status T5DecoderSubgraph::CreateInitialFeeds(
     AllocatorPtr cpu_allocator,
     gsl::span<const int32_t> beam_next_tokens,
@@ -173,33 +182,30 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
   // Allocate subgraph inputs from same device as inputs of encoder subgraph.
   AllocatorPtr allocator = session_state_->GetAllocator(encoder_feeds[0].Get<Tensor>().Location());

+  int batch_beam_size = static_cast<int>(encoder_fetches[0].Get<Tensor>().Shape()[0]) * num_beam;
+
   // Copy beam next tokens in CPU to input_ids in provider device (CPU for CPU EP, or GPU for CUDA EP).
-  int batch_beam_size = static_cast<int>(beam_next_tokens.size());
   int sequence_length = !use_sequence_as_input_ids ? 1 : cur_len;
   int64_t dims[] = {batch_beam_size, sequence_length};
   TensorShape input_ids_shape(&dims[0], 2);
   OrtValue input_ids;
   Tensor::InitOrtValue(DataTypeImpl::GetType<int32_t>(), input_ids_shape, allocator, input_ids);
-  int32_t* input_ids_data = input_ids.GetMutable<Tensor>()->MutableData<int32_t>();
-  AllocatorPtr buffer_allocator = std::make_shared<onnxruntime::CPUAllocator>();
-  size_t total_size = static_cast<size_t>(cur_len) * static_cast<size_t>(batch_beam_size);
-  size_t total_size_bytes = total_size * sizeof(int);
-  auto seq_copy = IAllocator::MakeUniquePtr<int>(buffer_allocator, total_size_bytes, false, stream);
-  int* seq_copy_ptr = seq_copy.get();
-
-  if (!use_sequence_as_input_ids_) {
+
+  // Prepare data for input_ids.
+  if (!use_sequence_as_input_ids_) {  // use next tokens for input_ids. This is for the Whisper model.
     ORT_RETURN_IF_ERROR(device_copy_int32_func(
         input_ids.GetMutable<Tensor>()->MutableDataAsSpan<int32_t>(),
         beam_next_tokens,
         stream,
         DeviceCopyDirection::hostToDevice));
   } else {
+    int32_t* input_ids_data = input_ids.GetMutable<Tensor>()->MutableData<int32_t>();
     if (use_cuda) {
       auto sequences_buffer = sequences.GetCurrentDeviceSequences();
       for (int i = 0; i < batch_beam_size; i++) {
-        size_t batch_beam_stride = static_cast<size_t>(i) * static_cast<size_t>(sequences.GetMaxLength());
+        size_t offset = static_cast<size_t>(i) * static_cast<size_t>(sequences.GetMaxLength());
         int seq_size = sequences.GetSequenceLength();
-        gsl::span<const int32_t> sequence = sequences_buffer.subspan(batch_beam_stride, seq_size);
+        gsl::span<const int32_t> sequence = sequences_buffer.subspan(offset, seq_size);
         gsl::span<int> temp_input(input_ids_data + static_cast<ptrdiff_t>(i) * seq_size, seq_size);
         ORT_RETURN_IF_ERROR(device_copy_int32_func(
             temp_input,
@@ -208,6 +214,13 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
             DeviceCopyDirection::deviceToDevice));
       }
     } else {
+      size_t total_size = static_cast<size_t>(cur_len) * static_cast<size_t>(batch_beam_size);
+      size_t total_size_bytes = total_size * sizeof(int);
+      AllocatorPtr buffer_allocator = std::make_shared<onnxruntime::CPUAllocator>();
+      // TODO: no need for an extra buffer. Copy directly to input_ids_data instead, like the use_cuda branch above.
+      auto seq_copy = IAllocator::MakeUniquePtr<int>(buffer_allocator, total_size_bytes, false, stream);
+      int* seq_copy_ptr = seq_copy.get();
+
       const size_t cur_len_bytes = cur_len * sizeof(int);
       for (int i = 0; i < batch_beam_size; i++) {
         gsl::span<const int32_t> sequence = sequences.GetSequence(i);
@@ -227,9 +240,11 @@ Status T5DecoderSubgraph::CreateInitialFeeds(

   // The ordering is the same as used in Setup.
   decoder_feeds.reserve(static_cast<size_t>(num_subgraph_inputs) + static_cast<size_t>(num_implicit_inputs));
+
+  // input 0: input_ids
   decoder_feeds.push_back(input_ids);

-  if (has_encoder_input_ids_) {
+  if (has_encoder_input_ids_) {  // encoder_input_ids is optional
     // The encoder_input_ids is copied from the first input of encoder.
     OrtValue expanded_encoder_input_ids;
     ORT_RETURN_IF_ERROR(expand_buffer_int32_func(stream,
@@ -251,70 +266,65 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
                                                  expanded_decoder_attention_masks,
                                                  false,
                                                  0 /*max_sequence_length*/));
-
   decoder_feeds.push_back(expanded_decoder_attention_masks);

   if (!past_present_share_buffer_) {
     past_present_share_buffer_max_seq_len = 0;
   }

-  // When first_past_input_index_ == 3, the encoder_hidden_states and past states are copied from the second output
-  // of encoder.
-  // When first_past_input_index_ == 2, the past states are copied from the second output of encoder.
-  // TODO - probably more robust to introduce a encoder_out/decoder_in mapping instead of relying on positions.
-  // What happens if encoder_hidden_states is present in the encoder_fetches but not in the decoder_feeds?
-  for (size_t j = static_cast<size_t>(2) - has_hidden_state_; j < encoder_fetches.size(); j++) {
-    if (j == 1) {
-      ORT_RETURN_IF(has_hidden_state_ == false, "Invalid hidden_states expension: has_hidden_state_ == false");
-      OrtValue expanded_hidden_states;
-      if (is_output_float16_) {
-        ORT_RETURN_IF_ERROR(expand_buffer_float16_func(stream,
-                                                       encoder_fetches[j],
-                                                       num_beam,
-                                                       allocator,
-                                                       expanded_hidden_states,
-                                                       false,
-                                                       0 /*max_sequence_length*/));
-      } else {
-        ORT_RETURN_IF_ERROR(expand_buffer_float_func(stream,
-                                                     encoder_fetches[j],
-                                                     num_beam,
-                                                     allocator,
-                                                     expanded_hidden_states,
-                                                     false,
-                                                     0 /*max_sequence_length*/));
-      }
-      decoder_feeds.push_back(expanded_hidden_states);
-    } else {
+  // Macro to expand encoder outputs and append them to the decoder feeds.
+#define ADD_DECODER_FEED(encoder_output, is_dynamic_kv_cache)                                                     \
+  OrtValue expanded;                                                                                              \
+  if (is_output_float16_) {                                                                                       \
+    ORT_RETURN_IF_ERROR(expand_buffer_float16_func(stream, encoder_output, num_beam, allocator, expanded, false,  \
+                                                   is_dynamic_kv_cache ? past_present_share_buffer_max_seq_len : 0)); \
+  } else {                                                                                                         \
+    ORT_RETURN_IF_ERROR(expand_buffer_float_func(stream, encoder_output, num_beam, allocator, expanded, false,    \
+                                                 is_dynamic_kv_cache ? past_present_share_buffer_max_seq_len : 0)); \
+  }                                                                                                                \
+  decoder_feeds.push_back(expanded);
+
+  // The encoder_hidden_states is copied from the second output of the encoder.
+  if (has_hidden_state_) {
+    ADD_DECODER_FEED(encoder_fetches[1], false);
+  }
+
+  // The new-format encoder has only cross outputs.
+  bool is_new_format = (static_cast<int>(encoder_fetches.size()) == 2 * num_layers);
+  if (is_new_format) {
+    for (int i = 0; i < 2 * num_layers; i++) {
+      // cross shape is (batch_size, num_heads, encode_sequence_length, head_size)
+      const TensorShape& cross_shape = encoder_fetches[0].Get<Tensor>().Shape();
+      ORT_ENFORCE(cross_shape.NumDimensions() == 4);
+
+      // Shape for the KV cache: (batch_size * num_beam, num_heads, max_seq_len, head_size)
+      int64_t cache_dims[4] = {0};
+      cross_shape.CopyDims(cache_dims, cross_shape.NumDimensions());
+      cache_dims[0] *= num_beam;
+      cache_dims[2] = past_present_share_buffer_max_seq_len;
+      TensorShape expanded_shape(&cache_dims[0], cross_shape.NumDimensions());
+
+      MLDataType element_type = encoder_fetches[0].Get<Tensor>().DataType();
+      OrtValue past;
+      Tensor::InitOrtValue(element_type, expanded_shape, allocator, past);
+      decoder_feeds.push_back(past);
+    }
+
+    // Add cross inputs from the encoder output.
+    for (size_t j = 0; j < encoder_fetches.size(); j++) {
+      ADD_DECODER_FEED(encoder_fetches[j], false);
+    }
+  } else {
+    for (size_t j = 1 + has_hidden_state_; j < encoder_fetches.size(); j++) {
       // past key/value for cross attention does not need to be initialized with max_seq_len since they are static.
-      bool use_max_seq_len = (j - first_past_input_index_) < 2 * static_cast<size_t>(num_layers);
-
-      OrtValue expanded_cache;
-      if (is_output_float16_) {
-        ORT_RETURN_IF_ERROR(expand_buffer_float16_func(stream,
-                                                       encoder_fetches[j],
-                                                       num_beam,
-                                                       allocator,
-                                                       expanded_cache,
-                                                       false,
-                                                       use_max_seq_len ? past_present_share_buffer_max_seq_len : 0));
-      } else {
-        ORT_RETURN_IF_ERROR(expand_buffer_float_func(stream,
-                                                     encoder_fetches[j],
-                                                     num_beam,
-                                                     allocator,
-                                                     expanded_cache,
-                                                     false,
-                                                     use_max_seq_len ? past_present_share_buffer_max_seq_len : 0));
-      }
-      decoder_feeds.push_back(expanded_cache);
+      bool is_dynamic_kv_cache = (j - first_past_input_index_) < 2 * static_cast<size_t>(num_layers);
+      ADD_DECODER_FEED(encoder_fetches[j], is_dynamic_kv_cache);
     }
   }

-  // TODO: This part shares the similar logic with CreateInitialFeeds() in subgraph_gpt.cc. We should refactor it.
   if (past_present_share_buffer_) {
-    // Past sequence length feed
-    ORT_RETURN_IF_ERROR(AppendPastSequenceLength(decoder_feeds, cpu_allocator, 1));
+    // Past sequence length: 0 for the new format, 1 for the old format.
+    ORT_RETURN_IF_ERROR(AppendPastSequenceLength(decoder_feeds, cpu_allocator, is_new_format ? 0 : 1));
     // Add beam search specific inputs
     if (need_cache_indir) {
       const int64_t batch_size = static_cast<int64_t>(batch_beam_size / num_beam);
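To summarize the feed layout that CreateInitialFeeds now assembles for the new format, here is a hedged sketch that simply prints the ordering. The string entries stand in for OrtValue tensors, kNumLayers is a placeholder rather than the real layer count, the names follow the comment block above, and past_sequence_length is only appended when past/present buffer sharing is enabled.

// Hedged sketch of the new-format decoder feed ordering; illustration only.
#include <iostream>
#include <string>
#include <vector>

int main() {
  const int kNumLayers = 2;  // placeholder; the real value comes from the decoder subgraph

  std::vector<std::string> decoder_feeds;
  decoder_feeds.push_back("input_ids");               // holds the start token(s) for the new format
  decoder_feeds.push_back("encoder_attention_mask");  // expanded by the number of beams

  // Self-attention KV cache: freshly allocated buffers shaped
  // (batch_size * num_beams, num_heads, max_seq_len, head_size).
  for (int i = 0; i < kNumLayers; ++i) {
    decoder_feeds.push_back("present_key_self_" + std::to_string(i));
    decoder_feeds.push_back("present_value_self_" + std::to_string(i));
  }

  // Cross-attention KV cache: the encoder outputs, expanded by the number of beams.
  for (int i = 0; i < kNumLayers; ++i) {
    decoder_feeds.push_back("present_key_cross_" + std::to_string(i));
    decoder_feeds.push_back("present_value_cross_" + std::to_string(i));
  }

  // Only when past/present buffer sharing is enabled; starts at 0 in the new format.
  decoder_feeds.push_back("past_sequence_length");

  for (const auto& feed : decoder_feeds) {
    std::cout << feed << "\n";
  }
  return 0;
}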

onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h
Lines changed: 4 additions & 0 deletions
@@ -72,6 +72,10 @@ class T5DecoderSubgraph : public Subgraph {
     return use_sequence_as_input_ids_;
   }

+  inline bool UseEncoderHiddenState() const {
+    return has_hidden_state_;
+  }
+
  protected:
   int first_past_input_index_;
   int first_present_output_index_;