Update T5 Onnx Export and Optimization #23949

Merged · 10 commits · Mar 23, 2025
16 changes: 11 additions & 5 deletions onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
@@ -139,13 +139,19 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state,
     ORT_RETURN_IF_ERROR(t5_encoder_subgraph_->Setup(session_state, subgraph_session_state));
     encoder_feeds_fetches_manager_ = t5_encoder_subgraph_->GetFeedsFetchesManager();

-    if (parameters_->decoder_start_token_id < 0) {
-      ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 2,
-                    "Encoder subgraph shall have 2 inputs when decoder_start_token_id attribute is empty");
+    if (!t5_encoder_subgraph_->HasLogitsOutput()) {
+      // New format requires start token id.
+      ORT_ENFORCE(parameters_->decoder_start_token_id >= 0);
     } else {
-      ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 3,
-                    "Encoder subgraph shall have 3 inputs when decoder_start_token_id attribute is available");
+      if (parameters_->decoder_start_token_id < 0) {
+        ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 2,
+                      "Encoder subgraph shall have 2 inputs when decoder_start_token_id attribute is empty");
+      } else {
+        ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 3,
+                      "Encoder subgraph shall have 3 inputs when decoder_start_token_id attribute is available");
+      }
     }

   } else if (attribute_name == "decoder") {
     ORT_ENFORCE(t5_decoder_subgraph_ == nullptr,
                 "SetupSubgraphExecutionInfo should only be called once for each subgraph.");
71 changes: 47 additions & 24 deletions onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h
@@ -51,7 +51,13 @@ class BeamSearchT5 : public BeamSearchBase<T> {
         expand_buffer_int32_func_(expand_buffer_int32_func),
         expand_buffer_float_func_(expand_buffer_float_func),
         expand_buffer_float16_func_(expand_buffer_float16_func),
-        create_beam_scorer_func_(create_beam_scorer_func) {}
+        create_beam_scorer_func_(create_beam_scorer_func) {
+    // When the decoder uses encoder_hidden_states, make sure the encoder outputs it.
+    if (decoder_subgraph_.UseEncoderHiddenState()) {
+      ORT_ENFORCE(encoder_subgraph_.subgraph_output_names[1] == "encoder_hidden_states");
+    }
+    ORT_ENFORCE(encoder_subgraph_.num_layers == decoder_subgraph_.num_layers);
+  }

 #ifdef USE_CUDA
   Status InitializeCuda(
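
The new constructor body pins down two invariants between the paired subgraphs: if the decoder consumes encoder_hidden_states, the encoder must expose it as its second output, and both subgraphs must agree on num_layers. A hedged offline equivalent follows; the file paths and the past/present tensor-name prefixes are assumptions, not taken from this PR.

import onnx

encoder = onnx.load("t5_encoder.onnx")  # placeholder paths
decoder = onnx.load("t5_decoder.onnx")

encoder_outputs = [o.name for o in encoder.graph.output]
decoder_inputs = [i.name for i in decoder.graph.input]

# Mirror of the first ORT_ENFORCE: encoder output index 1 must be
# encoder_hidden_states whenever the decoder consumes that tensor.
if "encoder_hidden_states" in decoder_inputs:
    assert encoder_outputs[1] == "encoder_hidden_states", encoder_outputs

# Mirror of the second ORT_ENFORCE: infer num_layers on each side by
# counting per-layer self-attention cache tensors (name prefixes assumed).
encoder_layers = sum(n.startswith("present_key_self_") for n in encoder_outputs)
decoder_layers = sum(n.startswith("past_key_self_") for n in decoder_inputs)
assert encoder_layers == decoder_layers, (encoder_layers, decoder_layers)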
@@ -160,7 +166,7 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
                                             this->create_encoder_inputs_func_,
                                             this->add_to_feeds_func_,
                                             buffer,
-                                            decoder_input_ids,
+                                            decoder_input_ids,  // The new format does not use decoder_input_ids in the encoder; it is still initialized here when decoder_start_token_id >= 0.
                                             this->ort_stream_));

 #ifdef DEBUG_NODE_INPUTS_OUTPUTS
@@ -233,35 +239,47 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches

     std::vector<OrtValue> decoder_fetches;

-    if (current_length + 1 < parameters->max_length) {
+    // When the encoder outputs logits (old format), we need to get the next token from the logits.
+    if (current_length + 1 < parameters->max_length && encoder_subgraph_.HasLogitsOutput()) {
       ++iteration_counter;
-      ORT_RETURN_IF_ERROR(this->GenerateNextToken(encoder_fetches[0],
+      const OrtValue& logits = encoder_fetches[0];
+      ORT_RETURN_IF_ERROR(this->GenerateNextToken(logits,
                                                   beam_next_tokens,
                                                   beam_state,
                                                   cpu_state,
                                                   iteration_counter));
+      ++current_length;  // Increase sequence length after a new token is generated.
     }

-    ORT_RETURN_IF_ERROR(decoder_subgraph_.CreateInitialFeeds(this->cpu_allocator_,
-                                                             ReinterpretAsSpan<const int32_t>(beam_next_tokens),
-                                                             this->implicit_inputs_,
-                                                             encoder_feeds,
-                                                             encoder_fetches,
-                                                             decoder_feeds,
-                                                             this->device_copy_int32_func_,
-                                                             this->expand_buffer_int32_func_,
-                                                             this->expand_buffer_float_func_,
-                                                             this->expand_buffer_float16_func_,
-                                                             parameters->num_beams,
-                                                             this->ort_stream_,
-                                                             decoder_subgraph_.UseSequenceAsInputIds(),
-                                                             current_length,
-                                                             cpu_state.sequences,
-                                                             parameters->max_length,
-                                                             decoder_subgraph_.has_decoder_masked_attention_,
-                                                             this->cuda_device_prop_ != nullptr));
+    if (current_length < parameters->max_length) {
+      // When there are no logits, copy the sequence (filled with start token IDs) to input_ids for the decoder.
+      bool copy_sequence_to_input_ids = decoder_subgraph_.UseSequenceAsInputIds() || !encoder_subgraph_.HasLogitsOutput();
+      if (copy_sequence_to_input_ids) {
+        ORT_ENFORCE(current_length == cpu_state.sequences.GetSequenceLength());
+      }
+
+      // Generate inputs for the next decoder subgraph call.
+      ORT_RETURN_IF_ERROR(decoder_subgraph_.CreateInitialFeeds(
+          this->cpu_allocator_,
+          ReinterpretAsSpan<const int32_t>(beam_next_tokens),
+          this->implicit_inputs_,
+          encoder_feeds,
+          encoder_fetches,
+          decoder_feeds,
+          this->device_copy_int32_func_,
+          this->expand_buffer_int32_func_,
+          this->expand_buffer_float_func_,
+          this->expand_buffer_float16_func_,
+          parameters->num_beams,
+          this->ort_stream_,
+          copy_sequence_to_input_ids,
+          cpu_state.sequences,
+          parameters->max_length,
+          decoder_subgraph_.has_decoder_masked_attention_,
+          this->cuda_device_prop_ != nullptr));

     if (decoder_subgraph_.past_present_share_buffer_) {
+      // Configure buffer sharing of the past and present KV cache.
       decoder_fetches.reserve(static_cast<size_t>(decoder_subgraph_.GetFirstPresentOutputIndex()) +
                               2 * static_cast<size_t>(decoder_subgraph_.num_layers));
       decoder_fetches.resize(decoder_subgraph_.GetFirstPresentOutputIndex(), OrtValue());
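
The restructured block gives the first decoder step two distinct paths: with encoder logits (old format), the first real token comes from GenerateNextToken and current_length grows by one; without logits (new format), the sequence, which at this point holds only start tokens, is copied into the decoder's input_ids via copy_sequence_to_input_ids. A toy NumPy sketch of the two paths follows; shapes, the greedy argmax pick, and all values are illustrative assumptions.

import numpy as np

batch_beam, vocab_size = 2, 8
decoder_start_token_id = 0
encoder_has_logits = False  # toggle to emulate old (True) vs. new (False) format

# Sequences start out holding only the start token.
sequences = np.full((batch_beam, 1), decoder_start_token_id, dtype=np.int32)

if encoder_has_logits:
    # Old format: pick the next token from the encoder subgraph's logits
    # (a greedy argmax stands in for the real beam scorer here).
    logits = np.random.rand(batch_beam, 1, vocab_size).astype(np.float32)
    next_tokens = logits[:, -1, :].argmax(axis=-1).astype(np.int32)
    sequences = np.concatenate([sequences, next_tokens[:, None]], axis=1)
    decoder_input_ids = next_tokens[:, None]  # one new token per beam
else:
    # New format: no encoder logits, so the decoder's first input_ids are
    # the sequence so far, i.e. just the start tokens.
    decoder_input_ids = sequences.copy()

print(decoder_input_ids)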
@@ -299,14 +317,19 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches

     while (current_length < parameters->max_length) {
       iteration_counter++;
+
 #ifdef DEBUG_GENERATION
-      auto cur_len = std::to_string(current_length);
-      dumper->Print("***CurrentLength", cur_len, true);
+      dumper->Print(::onnxruntime::MakeString("Iteration=", iteration_counter,
+                                              ", CurrentLength=", current_length,
+                                              ", num_layers=", decoder_subgraph_.num_layers,
+                                              ", decoder_feeds=", decoder_feeds.size(),
+                                              ", start_token_id=", parameters->decoder_start_token_id));
+
       for (int i = 0; i < decoder_subgraph_.GetFirstPastInputIndex(); i++) {
         dumper->Print("decoder_feeds", i, true);
         dumper->Print("", decoder_feeds[i]);
       }

       for (int i = 0; i < decoder_subgraph_.num_layers; i++) {
         int self_key_idx = decoder_subgraph_.GetFirstPastInputIndex() + 2 * i;
         int self_value_idx = self_key_idx + 1;
18 changes: 8 additions & 10 deletions onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc
@@ -36,12 +36,9 @@ Subgraph::Subgraph(
   auto& subgraph_inputs = subgraph.GetInputs();
   auto& subgraph_outputs = subgraph.GetOutputs();

-  // inputs: input_ids, position_ids, attention_mask, past_0, past_1, ...
-  // outputs: logits, present_0, present_1, ...
   num_subgraph_inputs = static_cast<int>(subgraph_inputs.size());
   num_subgraph_outputs = static_cast<int>(subgraph_outputs.size());

-  // CheckSubgraph will verify inputs and outputs later.
   subgraph_input_names.reserve(num_subgraph_inputs);
   for (int i = 0; i < num_subgraph_inputs; ++i) {
     subgraph_input_names.push_back(subgraph_inputs[i]->Name());
@@ -68,10 +65,9 @@ Status Subgraph::Setup(const SessionState& session_state,
   InlinedVector<std::string_view> feed_names;
   feed_names.reserve(static_cast<size_t>(num_subgraph_inputs) + static_cast<size_t>(num_implicit_inputs));

-  // Use the first output (logits) to find device location.
+  // Use the first output to find device location.
   const OrtDevice& default_location = utils::FindDeviceForValue(subgraph_session_state, subgraph_output_names[0]);

-  // The position_ids, attention_mask, past_0, ... are created by this operator so the name doesn't matter.
   feed_names.insert(feed_names.end(), subgraph_input_names.begin(), subgraph_input_names.end());

   const auto& subgraph_map = subgraph_session_state.GetOrtValueNameIdxMap();
@@ -174,13 +170,15 @@ Status Subgraph::GetParameters(const ONNX_NAMESPACE::TensorShapeProto* past_shap
   }

   // Logits shape is like (batch_size, seq_len, vocabulary_size)
-  ORT_RETURN_IF(logits_shape->dim_size() != 3,
-                "subgraph logits output is expected to have 3 dimension, got ", logits_shape->dim_size());
+  if (logits_shape != nullptr) {
+    ORT_RETURN_IF(logits_shape->dim_size() != 3,
+                  "subgraph logits output is expected to have 3 dimension, got ", logits_shape->dim_size());

-  ORT_RETURN_IF(!logits_shape->dim(2).has_dim_value() || logits_shape->dim(2).dim_value() <= 0,
-                "subgraph past state dimension 2 shall have a positive value for vocabulary size");
+    ORT_RETURN_IF(!logits_shape->dim(2).has_dim_value() || logits_shape->dim(2).dim_value() <= 0,
+                  "subgraph past state dimension 2 shall have a positive value for vocabulary size");

-  this->vocab_size = static_cast<int>(logits_shape->dim(2).dim_value());
+    this->vocab_size = static_cast<int>(logits_shape->dim(2).dim_value());
+  }

   return Status::OK();
 }
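
With this change, GetParameters validates the logits shape only when the subgraph actually has a logits output, which the new-format encoder does not. A hedged Python analogue using onnx shape metadata follows; the model path is a placeholder, and the "logits" output name is assumed.

import onnx

model = onnx.load("t5_decoder.onnx")  # placeholder path
logits_info = next((o for o in model.graph.output if o.name == "logits"), None)

if logits_info is not None:  # mirrors the new `logits_shape != nullptr` guard
    dims = logits_info.type.tensor_type.shape.dim
    # Logits shape is (batch_size, seq_len, vocab_size); the last dim must
    # be a concrete positive value so vocab_size can be read from the graph.
    assert len(dims) == 3, f"expected rank 3, got {len(dims)}"
    assert dims[2].HasField("dim_value") and dims[2].dim_value > 0
    vocab_size = dims[2].dim_value
    print("vocab_size =", vocab_size)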