@@ -141,14 +141,23 @@ Status T5DecoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_i
}

// Create inputs for decoder from the following data sources:
- // encoder feeds: encoder_input_ids, encoder_attention_mask, decoder_input_ids (with start tokens)
- // encoder fetches: logits,
- //                  encoder_hidden_states,
- //                  present_key_self_0, present_value_self_0, ..., present_key_cross_0, present_value_cross_0, ...
- // decoder_feeds: input_ids,
- //                encoder_attention_mask,
- //                encoder_hidden_states,
- //                present_key_self_0, present_value_self_0, ..., present_key_cross_0, present_value_cross_0, ...
+ // New format:
+ //   encoder feeds: encoder_input_ids, encoder_attention_mask
+ //   encoder fetches: present_key_cross_0, present_value_cross_0, ...
+ //   decoder_feeds: input_ids, encoder_attention_mask,
+ //                  present_key_self_0, present_value_self_0, ...,
+ //                  present_key_cross_0, present_value_cross_0, ...
+
+ // Old format:
+ //   encoder feeds: encoder_input_ids, encoder_attention_mask, decoder_input_ids (with start tokens)
+ //   encoder fetches: logits, encoder_hidden_states,
+ //                    present_key_self_0, present_value_self_0, ...,
+ //                    present_key_cross_0, present_value_cross_0, ...
+ //   decoder_feeds: input_ids, encoder_input_ids (optional), encoder_attention_mask, encoder_hidden_states,
+ //                  present_key_self_0, present_value_self_0, ...,
+ //                  present_key_cross_0, present_value_cross_0, ...
+ //                  past_seq_len (optional), num_beams (optional), cache_indirection (optional)
+
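+ // Illustrative example (hypothetical, assuming num_layers == 2): in the new format the decoder
+ // feeds would be ordered as
+ //   input_ids, encoder_attention_mask,
+ //   present_key_self_0, present_value_self_0, present_key_self_1, present_value_self_1,
+ //   present_key_cross_0, present_value_cross_0, present_key_cross_1, present_value_cross_1
+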
Status T5DecoderSubgraph::CreateInitialFeeds(
    AllocatorPtr cpu_allocator,
    gsl::span<const int32_t> beam_next_tokens,
@@ -173,33 +182,31 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
// Allocate subgraph inputs from same device as inputs of encoder subgraph.
AllocatorPtr allocator = session_state_->GetAllocator(encoder_feeds[0].Get<Tensor>().Location());

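+ // Dim 0 of encoder_fetches[0] is batch_size in both formats (logits in the old format,
+ // present_key_cross_0 in the new format), so scaling it by num_beam gives the batch-beam size.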
+ int batch_beam_size = static_cast<int>(encoder_fetches[0].Get<Tensor>().Shape()[0]) * num_beam;
+
// Copy beam next tokens in CPU to input_ids in provider device (CPU for CPU EP, or GPU for CUDA EP).
- int batch_beam_size = static_cast<int>(beam_next_tokens.size());
int sequence_length = !use_sequence_as_input_ids ? 1 : cur_len;
int64_t dims[] = {batch_beam_size, sequence_length};
TensorShape input_ids_shape(&dims[0], 2);
OrtValue input_ids;
Tensor::InitOrtValue(DataTypeImpl::GetType<int32_t>(), input_ids_shape, allocator, input_ids);
- int32_t* input_ids_data = input_ids.GetMutable<Tensor>()->MutableData<int32_t>();
- AllocatorPtr buffer_allocator = std::make_shared<onnxruntime::CPUAllocator>();
- size_t total_size = static_cast<size_t>(cur_len) * static_cast<size_t>(batch_beam_size);
- size_t total_size_bytes = total_size * sizeof(int);
- auto seq_copy = IAllocator::MakeUniquePtr<int>(buffer_allocator, total_size_bytes, false, stream);
- int* seq_copy_ptr = seq_copy.get();
-
- if (!use_sequence_as_input_ids_) {
+
+ // Prepare data for input_ids.
+ if (!use_sequence_as_input_ids_) {  // Use next tokens for input_ids. This is the case for the Whisper model.
+   int batch_beam_size = static_cast<int>(beam_next_tokens.size());
  ORT_RETURN_IF_ERROR(device_copy_int32_func(
      input_ids.GetMutable<Tensor>()->MutableDataAsSpan<int32_t>(),
      beam_next_tokens,
      stream,
      DeviceCopyDirection::hostToDevice));
} else {
+   int32_t* input_ids_data = input_ids.GetMutable<Tensor>()->MutableData<int32_t>();
  if (use_cuda) {
    auto sequences_buffer = sequences.GetCurrentDeviceSequences();
    for (int i = 0; i < batch_beam_size; i++) {
-       size_t batch_beam_stride = static_cast<size_t>(i) * static_cast<size_t>(sequences.GetMaxLength());
+       size_t offset = static_cast<size_t>(i) * static_cast<size_t>(sequences.GetMaxLength());
      int seq_size = sequences.GetSequenceLength();
-       gsl::span<const int32_t> sequence = sequences_buffer.subspan(batch_beam_stride, seq_size);
+       gsl::span<const int32_t> sequence = sequences_buffer.subspan(offset, seq_size);
      gsl::span<int> temp_input(input_ids_data + static_cast<ptrdiff_t>(i) * seq_size, seq_size);
      ORT_RETURN_IF_ERROR(device_copy_int32_func(
          temp_input,
@@ -208,6 +215,13 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
          DeviceCopyDirection::deviceToDevice));
    }
  } else {
+     size_t total_size = static_cast<size_t>(cur_len) * static_cast<size_t>(batch_beam_size);
+     size_t total_size_bytes = total_size * sizeof(int);
+     AllocatorPtr buffer_allocator = std::make_shared<onnxruntime::CPUAllocator>();
+     // TODO: No extra buffer is needed. Copy directly into input_ids_data instead, like the use_cuda path above.
+     auto seq_copy = IAllocator::MakeUniquePtr<int>(buffer_allocator, total_size_bytes, false, stream);
+     int* seq_copy_ptr = seq_copy.get();
+
    const size_t cur_len_bytes = cur_len * sizeof(int);
    for (int i = 0; i < batch_beam_size; i++) {
      gsl::span<const int32_t> sequence = sequences.GetSequence(i);
@@ -227,9 +241,11 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
// The ordering is the same as used in Setup.
decoder_feeds.reserve(static_cast<size_t>(num_subgraph_inputs) + static_cast<size_t>(num_implicit_inputs));
+
+ // input 0: input_ids
decoder_feeds.push_back(input_ids);

- if (has_encoder_input_ids_) {
+ if (has_encoder_input_ids_) {  // encoder_input_ids is optional
  // The encoder_input_ids is copied from the first input of encoder.
  OrtValue expanded_encoder_input_ids;
  ORT_RETURN_IF_ERROR(expand_buffer_int32_func(stream,
@@ -251,70 +267,64 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
    expanded_decoder_attention_masks,
    false,
    0 /*max_sequence_length*/));
-
decoder_feeds.push_back(expanded_decoder_attention_masks);

if (!past_present_share_buffer_) {
  past_present_share_buffer_max_seq_len = 0;
}

- // When first_past_input_index_ == 3, the encoder_hidden_states and past states are copied from the second output
- // of encoder.
- // When first_past_input_index_ == 2, the past states are copied from the second output of encoder.
- // TODO - probably more robust to introduce a encoder_out/decoder_in mapping instead of relying on positions.
- // What happens if encoder_hidden_states is present in the encoder_fetches but not in the decoder_feeds?
- for (size_t j = static_cast<size_t>(2) - has_hidden_state_; j < encoder_fetches.size(); j++) {
-   if (j == 1) {
-     ORT_RETURN_IF(has_hidden_state_ == false, "Invalid hidden_states expension: has_hidden_state_ == false");
-     OrtValue expanded_hidden_states;
-     if (is_output_float16_) {
-       ORT_RETURN_IF_ERROR(expand_buffer_float16_func(stream,
-           encoder_fetches[j],
-           num_beam,
-           allocator,
-           expanded_hidden_states,
-           false,
-           0 /*max_sequence_length*/));
-     } else {
-       ORT_RETURN_IF_ERROR(expand_buffer_float_func(stream,
-           encoder_fetches[j],
-           num_beam,
-           allocator,
-           expanded_hidden_states,
-           false,
-           0 /*max_sequence_length*/));
-     }
-     decoder_feeds.push_back(expanded_hidden_states);
-   } else {
+ // Helper macro to expand an encoder output across beams and append it to decoder_feeds.
+ // (A macro is used rather than a lambda so that ORT_RETURN_IF_ERROR returns from CreateInitialFeeds directly.)
+ #define ADD_DECODER_FEED(encoder_output, is_dynamic_kv_cache) \
+   OrtValue expanded; \
+   if (is_output_float16_) { \
+     ORT_RETURN_IF_ERROR(expand_buffer_float16_func(stream, encoder_output, num_beam, allocator, expanded, false, \
+         is_dynamic_kv_cache ? past_present_share_buffer_max_seq_len : 0)); \
+   } else { \
+     ORT_RETURN_IF_ERROR(expand_buffer_float_func(stream, encoder_output, num_beam, allocator, expanded, false, \
+         is_dynamic_kv_cache ? past_present_share_buffer_max_seq_len : 0)); \
+   } \
+   decoder_feeds.push_back(expanded);
+
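+ // For example, ADD_DECODER_FEED(encoder_fetches[1], false) conceptually expands to:
+ //   OrtValue expanded;
+ //   ORT_RETURN_IF_ERROR(expand_buffer_float_func(  // or expand_buffer_float16_func when is_output_float16_
+ //       stream, encoder_fetches[1], num_beam, allocator, expanded, false, 0));
+ //   decoder_feeds.push_back(expanded);
+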
+ // The encoder_hidden_states is copied from the second output of encoder.
+ if (has_hidden_state_) {
+   ADD_DECODER_FEED(encoder_fetches[1], false);
+ }
+
+ // The new-format encoder has only cross-attention outputs, so allocate the self-attention KV cache buffers here.
+ if (static_cast<int>(encoder_fetches.size()) == 2 * num_layers) {
+   for (int i = 0; i < 2 * num_layers; i++) {
+     // cross shape is (batch_size, num_heads, encode_sequence_length, head_size)
+     const TensorShape& cross_shape = encoder_fetches[0].Get<Tensor>().Shape();
+     ORT_ENFORCE(cross_shape.NumDimensions() == 4);
+
+     // Shape for the self-attention kv cache: (batch_size * num_beam, num_heads, max_seq_len, head_size)
+     int64_t dims[4] = {0};
+     cross_shape.CopyDims(dims, cross_shape.NumDimensions());
+     dims[0] *= num_beam;
+     dims[2] = past_present_share_buffer_max_seq_len;
+     TensorShape expanded_shape(&dims[0], cross_shape.NumDimensions());
+
+     MLDataType element_type = encoder_fetches[0].Get<Tensor>().DataType();
+     OrtValue past;
+     Tensor::InitOrtValue(element_type, expanded_shape, allocator, past);
+     decoder_feeds.push_back(past);
+   }
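+   // Worked example for the allocation above (hypothetical numbers): a cross K/V of shape
+   // (batch_size=2, num_heads=8, encode_sequence_length=1500, head_size=64) with num_beam=4 and
+   // past_present_share_buffer_max_seq_len=448 yields self-attention cache buffers of shape
+   // (8, 8, 448, 64).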
+
+   // Add cross inputs from encoder output.
+   for (size_t j = 0; j < encoder_fetches.size(); j++) {
+     ADD_DECODER_FEED(encoder_fetches[j], false);
+   }
+ } else {
+   for (size_t j = 1 + has_hidden_state_; j < encoder_fetches.size(); j++) {
    // past key/value for cross attention does not need to be initialized with max_seq_len since they are static.
-     bool use_max_seq_len = (j - first_past_input_index_) < 2 * static_cast<size_t>(num_layers);
-
-     OrtValue expanded_cache;
-     if (is_output_float16_) {
-       ORT_RETURN_IF_ERROR(expand_buffer_float16_func(stream,
-           encoder_fetches[j],
-           num_beam,
-           allocator,
-           expanded_cache,
-           false,
-           use_max_seq_len ? past_present_share_buffer_max_seq_len : 0));
-     } else {
-       ORT_RETURN_IF_ERROR(expand_buffer_float_func(stream,
-           encoder_fetches[j],
-           num_beam,
-           allocator,
-           expanded_cache,
-           false,
-           use_max_seq_len ? past_present_share_buffer_max_seq_len : 0));
-     }
-     decoder_feeds.push_back(expanded_cache);
+     bool is_dynamic_kv_cache = (j - first_past_input_index_) < 2 * static_cast<size_t>(num_layers);
+     ADD_DECODER_FEED(encoder_fetches[j], is_dynamic_kv_cache);
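+     // (is_dynamic_kv_cache is true only for the self-attention caches, which are expanded to
+     // past_present_share_buffer_max_seq_len; the cross-attention caches keep their static length.)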
  }
}

- // TODO: This part shares the similar logic with CreateInitialFeeds() in subgraph_gpt.cc. We should refactor it.
if (past_present_share_buffer_) {
-   // Past sequence length feed
-   ORT_RETURN_IF_ERROR(AppendPastSequenceLength(decoder_feeds, cpu_allocator, 1));
+   // Past sequence length set to 0
+   ORT_RETURN_IF_ERROR(AppendPastSequenceLength(decoder_feeds, cpu_allocator, 0));
  // Add beam search specific inputs
  if (need_cache_indir) {
    const int64_t batch_size = static_cast<int64_t>(batch_beam_size / num_beam);