@@ -141,14 +141,23 @@ Status T5DecoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_i
 }

 // Create inputs for decoder from the following data sources:
-// encoder feeds: encoder_input_ids, encoder_attention_mask, decoder_input_ids (with start tokens)
-// encoder fetches: logits,
-//                  encoder_hidden_states,
-//                  present_key_self_0, present_value_self_0, ..., present_key_cross_0, present_value_cross_0, ...
-// decoder_feeds: input_ids,
-//                encoder_attention_mask,
-//                encoder_hidden_states,
-//                present_key_self_0, present_value_self_0, ..., present_key_cross_0, present_value_cross_0, ...
+// New format:
+//   encoder feeds: encoder_input_ids, encoder_attention_mask
+//   encoder fetches: present_key_cross_0, present_value_cross_0, ...
+//   decoder_feeds: input_ids, encoder_attention_mask,
+//                  present_key_self_0, present_value_self_0, ...,
+//                  present_key_cross_0, present_value_cross_0, ...
+
+// Old format:
+//   encoder feeds: encoder_input_ids, encoder_attention_mask, decoder_input_ids (with start tokens)
+//   encoder fetches: logits, encoder_hidden_states,
+//                    present_key_self_0, present_value_self_0, ...,
+//                    present_key_cross_0, present_value_cross_0, ...
+//   decoder_feeds: input_ids, encoder_input_ids (optional), encoder_attention_mask, encoder_hidden_states,
+//                  present_key_self_0, present_value_self_0, ...,
+//                  present_key_cross_0, present_value_cross_0, ...
+//                  past_seq_len (optional), num_beams (optional), cache_indirection (optional)
+
 Status T5DecoderSubgraph::CreateInitialFeeds(
     AllocatorPtr cpu_allocator,
     gsl::span<const int32_t> beam_next_tokens,
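For orientation, here is a minimal standalone sketch of how the new-format decoder feed ordering documented in the comment block above maps to flat feed indices (the optional past_seq_len, num_beams and cache_indirection inputs would follow the cross entries). The helper names and the layer count are illustrative assumptions, not ONNX Runtime symbols.

// Illustrative index layout of decoder_feeds in the new format:
//   input_ids (0), encoder_attention_mask (1),
//   then interleaved self key/value per layer, then interleaved cross key/value per layer.
#include <cstdio>

int SelfKeyFeedIndex(int layer) { return 2 + 2 * layer; }
int SelfValueFeedIndex(int layer) { return 3 + 2 * layer; }
int CrossKeyFeedIndex(int num_layers, int layer) { return 2 + 2 * num_layers + 2 * layer; }
int CrossValueFeedIndex(int num_layers, int layer) { return 3 + 2 * num_layers + 2 * layer; }

int main() {
  const int num_layers = 6;  // hypothetical decoder layer count
  std::printf("self key, layer 0    -> feed %d\n", SelfKeyFeedIndex(0));                 // 2
  std::printf("cross value, layer 5 -> feed %d\n", CrossValueFeedIndex(num_layers, 5));  // 25
  return 0;
}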
@@ -173,33 +182,30 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
   // Allocate subgraph inputs from same device as inputs of encoder subgraph.
   AllocatorPtr allocator = session_state_->GetAllocator(encoder_feeds[0].Get<Tensor>().Location());

+  int batch_beam_size = static_cast<int>(encoder_fetches[0].Get<Tensor>().Shape()[0]) * num_beam;
+
   // Copy beam next tokens in CPU to input_ids in provider device (CPU for CPU EP, or GPU for CUDA EP).
-  int batch_beam_size = static_cast<int>(beam_next_tokens.size());
   int sequence_length = !use_sequence_as_input_ids ? 1 : cur_len;
   int64_t dims[] = {batch_beam_size, sequence_length};
   TensorShape input_ids_shape(&dims[0], 2);
   OrtValue input_ids;
   Tensor::InitOrtValue(DataTypeImpl::GetType<int32_t>(), input_ids_shape, allocator, input_ids);
-  int32_t* input_ids_data = input_ids.GetMutable<Tensor>()->MutableData<int32_t>();
-  AllocatorPtr buffer_allocator = std::make_shared<onnxruntime::CPUAllocator>();
-  size_t total_size = static_cast<size_t>(cur_len) * static_cast<size_t>(batch_beam_size);
-  size_t total_size_bytes = total_size * sizeof(int);
-  auto seq_copy = IAllocator::MakeUniquePtr<int>(buffer_allocator, total_size_bytes, false, stream);
-  int* seq_copy_ptr = seq_copy.get();
-
-  if (!use_sequence_as_input_ids_) {
+
+  // Prepare data for input_ids.
+  if (!use_sequence_as_input_ids_) {  // use next tokens for input_ids. This is for the Whisper model.
     ORT_RETURN_IF_ERROR(device_copy_int32_func(
         input_ids.GetMutable<Tensor>()->MutableDataAsSpan<int32_t>(),
         beam_next_tokens,
         stream,
         DeviceCopyDirection::hostToDevice));
   } else {
+    int32_t* input_ids_data = input_ids.GetMutable<Tensor>()->MutableData<int32_t>();
     if (use_cuda) {
       auto sequences_buffer = sequences.GetCurrentDeviceSequences();
       for (int i = 0; i < batch_beam_size; i++) {
-        size_t batch_beam_stride = static_cast<size_t>(i) * static_cast<size_t>(sequences.GetMaxLength());
+        size_t offset = static_cast<size_t>(i) * static_cast<size_t>(sequences.GetMaxLength());
         int seq_size = sequences.GetSequenceLength();
-        gsl::span<const int32_t> sequence = sequences_buffer.subspan(batch_beam_stride, seq_size);
+        gsl::span<const int32_t> sequence = sequences_buffer.subspan(offset, seq_size);
         gsl::span<int> temp_input(input_ids_data + static_cast<ptrdiff_t>(i) * seq_size, seq_size);
         ORT_RETURN_IF_ERROR(device_copy_int32_func(
             temp_input,
@@ -208,6 +214,13 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
             DeviceCopyDirection::deviceToDevice));
       }
     } else {
+      size_t total_size = static_cast<size_t>(cur_len) * static_cast<size_t>(batch_beam_size);
+      size_t total_size_bytes = total_size * sizeof(int);
+      AllocatorPtr buffer_allocator = std::make_shared<onnxruntime::CPUAllocator>();
+      // TODO: No need for an extra buffer here. Copy directly to input_ids_data instead, like the use_cuda branch above.
+      auto seq_copy = IAllocator::MakeUniquePtr<int>(buffer_allocator, total_size_bytes, false, stream);
+      int* seq_copy_ptr = seq_copy.get();
+
       const size_t cur_len_bytes = cur_len * sizeof(int);
       for (int i = 0; i < batch_beam_size; i++) {
         gsl::span<const int32_t> sequence = sequences.GetSequence(i);
@@ -227,9 +240,11 @@ Status T5DecoderSubgraph::CreateInitialFeeds(

   // The ordering is the same as used in Setup.
   decoder_feeds.reserve(static_cast<size_t>(num_subgraph_inputs) + static_cast<size_t>(num_implicit_inputs));
+
+  // input 0: input_ids
   decoder_feeds.push_back(input_ids);

-  if (has_encoder_input_ids_) {
+  if (has_encoder_input_ids_) {  // encoder_input_ids is optional
     // The encoder_input_ids is copied from the first input of encoder.
     OrtValue expanded_encoder_input_ids;
     ORT_RETURN_IF_ERROR(expand_buffer_int32_func(stream,
@@ -251,70 +266,65 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
                                                expanded_decoder_attention_masks,
                                                false,
                                                0 /*max_sequence_length*/));
-
   decoder_feeds.push_back(expanded_decoder_attention_masks);

   if (!past_present_share_buffer_) {
     past_present_share_buffer_max_seq_len = 0;
   }

-  // When first_past_input_index_ == 3, the encoder_hidden_states and past states are copied from the second output
-  // of encoder.
-  // When first_past_input_index_ == 2, the past states are copied from the second output of encoder.
-  // TODO - probably more robust to introduce a encoder_out/decoder_in mapping instead of relying on positions.
-  // What happens if encoder_hidden_states is present in the encoder_fetches but not in the decoder_feeds?
-  for (size_t j = static_cast<size_t>(2) - has_hidden_state_; j < encoder_fetches.size(); j++) {
-    if (j == 1) {
-      ORT_RETURN_IF(has_hidden_state_ == false, "Invalid hidden_states expension: has_hidden_state_ == false");
-      OrtValue expanded_hidden_states;
-      if (is_output_float16_) {
-        ORT_RETURN_IF_ERROR(expand_buffer_float16_func(stream,
-                                                       encoder_fetches[j],
-                                                       num_beam,
-                                                       allocator,
-                                                       expanded_hidden_states,
-                                                       false,
-                                                       0 /*max_sequence_length*/));
-      } else {
-        ORT_RETURN_IF_ERROR(expand_buffer_float_func(stream,
-                                                     encoder_fetches[j],
-                                                     num_beam,
-                                                     allocator,
-                                                     expanded_hidden_states,
-                                                     false,
-                                                     0 /*max_sequence_length*/));
-      }
-      decoder_feeds.push_back(expanded_hidden_states);
-    } else {
+  // macro to expand encoder outputs and append to decoder feeds.
+#define ADD_DECODER_FEED(encoder_output, is_dynamic_kv_cache) \
+  OrtValue expanded; \
+  if (is_output_float16_) { \
+    ORT_RETURN_IF_ERROR(expand_buffer_float16_func(stream, encoder_output, num_beam, allocator, expanded, false, \
+                                                   is_dynamic_kv_cache ? past_present_share_buffer_max_seq_len : 0)); \
+  } else { \
+    ORT_RETURN_IF_ERROR(expand_buffer_float_func(stream, encoder_output, num_beam, allocator, expanded, false, \
+                                                 is_dynamic_kv_cache ? past_present_share_buffer_max_seq_len : 0)); \
+  } \
+  decoder_feeds.push_back(expanded);
+
+  // The encoder_hidden_states is copied from the second output of encoder.
+  if (has_hidden_state_) {
+    ADD_DECODER_FEED(encoder_fetches[1], false);
+  }
+
+  // New format of encoder has only cross outputs.
+  bool is_new_format = (static_cast<int>(encoder_fetches.size()) == 2 * num_layers);
+  if (is_new_format) {
+    for (int i = 0; i < 2 * num_layers; i++) {
+      // cross shape is (batch_size, num_heads, encode_sequence_length, head_size)
+      const TensorShape& cross_shape = encoder_fetches[0].Get<Tensor>().Shape();
+      ORT_ENFORCE(cross_shape.NumDimensions() == 4);
+
+      // Shape for kv cache: (batch_size * num_beam, num_heads, max_seq_len, head_size)
+      int64_t cache_dims[4] = {0};
+      cross_shape.CopyDims(cache_dims, cross_shape.NumDimensions());
+      cache_dims[0] *= num_beam;
+      cache_dims[2] = past_present_share_buffer_max_seq_len;
+      TensorShape expanded_shape(&cache_dims[0], cross_shape.NumDimensions());
+
+      MLDataType element_type = encoder_fetches[0].Get<Tensor>().DataType();
+      OrtValue past;
+      Tensor::InitOrtValue(element_type, expanded_shape, allocator, past);
+      decoder_feeds.push_back(past);
+    }
+
+    // Add cross inputs from encoder output.
+    for (size_t j = 0; j < encoder_fetches.size(); j++) {
+      ADD_DECODER_FEED(encoder_fetches[j], false);
+    }
+  } else {
+    for (size_t j = 1 + has_hidden_state_; j < encoder_fetches.size(); j++) {
       // past key/value for cross attention does not need to be initialized with max_seq_len since they are static.
-      bool use_max_seq_len = (j - first_past_input_index_) < 2 * static_cast<size_t>(num_layers);
-
-      OrtValue expanded_cache;
-      if (is_output_float16_) {
-        ORT_RETURN_IF_ERROR(expand_buffer_float16_func(stream,
-                                                       encoder_fetches[j],
-                                                       num_beam,
-                                                       allocator,
-                                                       expanded_cache,
-                                                       false,
-                                                       use_max_seq_len ? past_present_share_buffer_max_seq_len : 0));
-      } else {
-        ORT_RETURN_IF_ERROR(expand_buffer_float_func(stream,
-                                                     encoder_fetches[j],
-                                                     num_beam,
-                                                     allocator,
-                                                     expanded_cache,
-                                                     false,
-                                                     use_max_seq_len ? past_present_share_buffer_max_seq_len : 0));
-      }
-      decoder_feeds.push_back(expanded_cache);
+      bool is_dynamic_kv_cache = (j - first_past_input_index_) < 2 * static_cast<size_t>(num_layers);
+      ADD_DECODER_FEED(encoder_fetches[j], is_dynamic_kv_cache);
     }
   }

-  // TODO: This part shares the similar logic with CreateInitialFeeds() in subgraph_gpt.cc. We should refactor it.
   if (past_present_share_buffer_) {
-    // Past sequence length feed
-    ORT_RETURN_IF_ERROR(AppendPastSequenceLength(decoder_feeds, cpu_allocator, 1));
+    // Past sequence length is 0 for the new format; otherwise it is 1.
+    ORT_RETURN_IF_ERROR(AppendPastSequenceLength(decoder_feeds, cpu_allocator, is_new_format ? 0 : 1));
     // Add beam search specific inputs
     if (need_cache_indir) {
       const int64_t batch_size = static_cast<int64_t>(batch_beam_size / num_beam);
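As a companion to the new-format branch above, here is a minimal standalone sketch (not ONNX Runtime code) of the shape arithmetic used when pre-allocating the self-attention KV cache from the cross-attention output shape; the function name and the example numbers are assumptions for illustration.

// Illustrative only: derive the pre-allocated self-attention KV cache shape
// (batch_size * num_beam, num_heads, max_seq_len, head_size) from the
// cross-attention output shape (batch_size, num_heads, encode_sequence_length, head_size).
#include <array>
#include <cstdio>

std::array<long long, 4> SelfCacheShape(const std::array<long long, 4>& cross_shape,
                                        long long num_beam, long long max_seq_len) {
  std::array<long long, 4> cache = cross_shape;
  cache[0] *= num_beam;    // batch_size -> batch_size * num_beam
  cache[2] = max_seq_len;  // encode_sequence_length -> shared max sequence length
  return cache;
}

int main() {
  // Hypothetical values: batch 2, 8 heads, encoder length 64, head size 64; 4 beams; max length 256.
  auto cache = SelfCacheShape({2, 8, 64, 64}, 4, 256);
  std::printf("(%lld, %lld, %lld, %lld)\n", cache[0], cache[1], cache[2], cache[3]);  // (8, 8, 256, 64)
  return 0;
}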