/* * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "annotator/grammar/grammar-annotator.h" #include "annotator/feature-processor.h" #include "annotator/grammar/utils.h" #include "annotator/types.h" #include "utils/base/logging.h" #include "utils/grammar/callback-delegate.h" #include "utils/grammar/match.h" #include "utils/grammar/matcher.h" #include "utils/grammar/rules-utils.h" #include "utils/grammar/types.h" #include "utils/normalization.h" #include "utils/optional.h" #include "utils/utf8/unicodetext.h" namespace libtextclassifier3 { namespace { // Returns the unicode codepoint offsets in a utf8 encoded text. std::vector UnicodeCodepointOffsets( const UnicodeText& text) { std::vector offsets; for (auto it = text.begin(); it != text.end(); it++) { offsets.push_back(it); } offsets.push_back(text.end()); return offsets; } } // namespace class GrammarAnnotatorCallbackDelegate : public grammar::CallbackDelegate { public: explicit GrammarAnnotatorCallbackDelegate( const UniLib* unilib, const GrammarModel* model, const ReflectiveFlatbufferBuilder* entity_data_builder, const ModeFlag mode) : unilib_(*unilib), model_(model), entity_data_builder_(entity_data_builder), mode_(mode) {} // Handles a grammar rule match in the annotator grammar. void MatchFound(const grammar::Match* match, grammar::CallbackId type, int64 value, grammar::Matcher* matcher) override { switch (static_cast(type)) { case GrammarAnnotator::Callback::kRuleMatch: { HandleRuleMatch(match, /*rule_id=*/value); return; } default: grammar::CallbackDelegate::MatchFound(match, type, value, matcher); } } // Deduplicate and populate annotations from grammar matches. bool GetAnnotations(const std::vector& text, std::vector* annotations) const { for (const grammar::Derivation& candidate : grammar::DeduplicateDerivations(candidates_)) { // Check that assertions are fulfilled. if (!grammar::VerifyAssertions(candidate.match)) { continue; } if (!AddAnnotatedSpanFromMatch(text, candidate, annotations)) { return false; } } return true; } bool GetTextSelection(const std::vector& text, const CodepointSpan& selection, AnnotatedSpan* result) { std::vector selection_candidates; // Deduplicate and verify matches. auto maybe_interpretation = GetBestValidInterpretation( grammar::DeduplicateDerivations(GetOverlappingRuleMatches( selection, candidates_, /*only_exact_overlap=*/false))); if (!maybe_interpretation.has_value()) { return false; } const GrammarModel_::RuleClassificationResult* interpretation; const grammar::Match* match; std::tie(interpretation, match) = maybe_interpretation.value(); return InstantiateAnnotatedSpanFromInterpretation(text, interpretation, match, result); } // Provides a classification results from the grammar matches. bool GetClassification(const std::vector& text, const CodepointSpan& selection, ClassificationResult* classification) const { // Deduplicate and verify matches. auto maybe_interpretation = GetBestValidInterpretation( grammar::DeduplicateDerivations(GetOverlappingRuleMatches( selection, candidates_, /*only_exact_overlap=*/true))); if (!maybe_interpretation.has_value()) { return false; } // Instantiate result. const GrammarModel_::RuleClassificationResult* interpretation; const grammar::Match* match; std::tie(interpretation, match) = maybe_interpretation.value(); return InstantiateClassificationInterpretation(text, interpretation, match, classification); } private: // Handles annotation/selection/classification rule matches. void HandleRuleMatch(const grammar::Match* match, const int64 rule_id) { if ((model_->rule_classification_result()->Get(rule_id)->enabled_modes() & mode_) != 0) { candidates_.push_back(grammar::Derivation{match, rule_id}); } } // Computes the selection boundaries from a grammar match. CodepointSpan MatchSelectionBoundaries( const grammar::Match* match, const GrammarModel_::RuleClassificationResult* classification) const { if (classification->capturing_group() == nullptr) { // Use full match as selection span. return match->codepoint_span; } // Set information from capturing matches. CodepointSpan span{kInvalidIndex, kInvalidIndex}; // Gather active capturing matches. std::unordered_map capturing_matches; for (const grammar::MappingMatch* match : grammar::SelectAllOfType( match, grammar::Match::kMappingMatch)) { capturing_matches[match->id] = match; } // Compute span boundaries. for (int i = 0; i < classification->capturing_group()->size(); i++) { auto it = capturing_matches.find(i); if (it == capturing_matches.end()) { // Capturing group is not active, skip. continue; } const CapturingGroup* group = classification->capturing_group()->Get(i); if (group->extend_selection()) { if (span.first == kInvalidIndex) { span = it->second->codepoint_span; } else { span.first = std::min(span.first, it->second->codepoint_span.first); span.second = std::max(span.second, it->second->codepoint_span.second); } } } return span; } // Filters out results that do not overlap with a reference span. std::vector GetOverlappingRuleMatches( const CodepointSpan& selection, const std::vector& candidates, const bool only_exact_overlap) const { std::vector result; for (const grammar::Derivation& candidate : candidates) { // Discard matches that do not match the selection. // Simple check. if (!SpansOverlap(selection, candidate.match->codepoint_span)) { continue; } // Compute exact selection boundaries (without assertions and // non-capturing parts). const CodepointSpan span = MatchSelectionBoundaries( candidate.match, model_->rule_classification_result()->Get(candidate.rule_id)); if (!SpansOverlap(selection, span) || (only_exact_overlap && span != selection)) { continue; } result.push_back(candidate); } return result; } // Returns the best valid interpretation of a set of candidate matches. Optional> GetBestValidInterpretation( const std::vector& candidates) const { const GrammarModel_::RuleClassificationResult* best_interpretation = nullptr; const grammar::Match* best_match = nullptr; for (const grammar::Derivation& candidate : candidates) { if (!grammar::VerifyAssertions(candidate.match)) { continue; } const GrammarModel_::RuleClassificationResult* rule_classification_result = model_->rule_classification_result()->Get(candidate.rule_id); if (best_interpretation == nullptr || best_interpretation->priority_score() < rule_classification_result->priority_score()) { best_interpretation = rule_classification_result; best_match = candidate.match; } } // No valid interpretation found. Optional> result; if (best_interpretation != nullptr) { result = {best_interpretation, best_match}; } return result; } // Instantiates an annotated span from a rule match and appends it to the // result. bool AddAnnotatedSpanFromMatch( const std::vector& text, const grammar::Derivation& candidate, std::vector* result) const { if (candidate.rule_id < 0 || candidate.rule_id >= model_->rule_classification_result()->size()) { TC3_LOG(INFO) << "Invalid rule id."; return false; } const GrammarModel_::RuleClassificationResult* interpretation = model_->rule_classification_result()->Get(candidate.rule_id); result->emplace_back(); return InstantiateAnnotatedSpanFromInterpretation( text, interpretation, candidate.match, &result->back()); } bool InstantiateAnnotatedSpanFromInterpretation( const std::vector& text, const GrammarModel_::RuleClassificationResult* interpretation, const grammar::Match* match, AnnotatedSpan* result) const { result->span = MatchSelectionBoundaries(match, interpretation); ClassificationResult classification; if (!InstantiateClassificationInterpretation(text, interpretation, match, &classification)) { return false; } result->classification.push_back(classification); return true; } // Instantiates a classification result from a rule match. bool InstantiateClassificationInterpretation( const std::vector& text, const GrammarModel_::RuleClassificationResult* interpretation, const grammar::Match* match, ClassificationResult* classification) const { classification->collection = interpretation->collection_name()->str(); classification->score = interpretation->target_classification_score(); classification->priority_score = interpretation->priority_score(); // Assemble entity data. if (entity_data_builder_ == nullptr) { return true; } std::unique_ptr entity_data = entity_data_builder_->NewRoot(); if (interpretation->serialized_entity_data() != nullptr) { entity_data->MergeFromSerializedFlatbuffer( StringPiece(interpretation->serialized_entity_data()->data(), interpretation->serialized_entity_data()->size())); } if (interpretation->entity_data() != nullptr) { entity_data->MergeFrom(reinterpret_cast( interpretation->entity_data())); } // Populate entity data from the capturing matches. if (interpretation->capturing_group() != nullptr) { // Gather active capturing matches. std::unordered_map capturing_matches; for (const grammar::MappingMatch* match : grammar::SelectAllOfType( match, grammar::Match::kMappingMatch)) { capturing_matches[match->id] = match; } for (int i = 0; i < interpretation->capturing_group()->size(); i++) { auto it = capturing_matches.find(i); if (it == capturing_matches.end()) { // Capturing group is not active, skip. continue; } const CapturingGroup* group = interpretation->capturing_group()->Get(i); // Add static entity data. if (group->serialized_entity_data() != nullptr) { entity_data->MergeFromSerializedFlatbuffer( StringPiece(interpretation->serialized_entity_data()->data(), interpretation->serialized_entity_data()->size())); } // Set entity field from captured text. if (group->entity_field_path() != nullptr) { const grammar::Match* capturing_match = it->second; StringPiece group_text = StringPiece( text[capturing_match->codepoint_span.first].utf8_data(), text[capturing_match->codepoint_span.second].utf8_data() - text[capturing_match->codepoint_span.first].utf8_data()); UnicodeText normalized_group_text = UTF8ToUnicodeText(group_text, /*do_copy=*/false); if (group->normalization_options() != nullptr) { normalized_group_text = NormalizeText( unilib_, group->normalization_options(), normalized_group_text); } if (!entity_data->ParseAndSet(group->entity_field_path(), normalized_group_text.ToUTF8String())) { TC3_LOG(ERROR) << "Could not set entity data from capturing match."; return false; } } } } if (entity_data && entity_data->HasExplicitlySetFields()) { classification->serialized_entity_data = entity_data->Serialize(); } return true; } const UniLib& unilib_; const GrammarModel* model_; const ReflectiveFlatbufferBuilder* entity_data_builder_; const ModeFlag mode_; // All annotation/selection/classification rule match candidates. // Grammar rule matches are recorded, deduplicated and then instantiated. std::vector candidates_; }; GrammarAnnotator::GrammarAnnotator( const UniLib* unilib, const GrammarModel* model, const ReflectiveFlatbufferBuilder* entity_data_builder) : unilib_(*unilib), model_(model), lexer_(unilib, model->rules()), tokenizer_(BuildTokenizer(unilib, model->tokenizer_options())), entity_data_builder_(entity_data_builder), rules_locales_(grammar::ParseRulesLocales(model->rules())) {} bool GrammarAnnotator::Annotate(const std::vector& locales, const UnicodeText& text, std::vector* result) const { if (model_ == nullptr || model_->rules() == nullptr) { // Nothing to do. return true; } // Select locale matching rules. std::vector locale_rules = SelectLocaleMatchingShards(model_->rules(), rules_locales_, locales); if (locale_rules.empty()) { // Nothing to do. return true; } // Run the grammar. GrammarAnnotatorCallbackDelegate callback_handler( &unilib_, model_, entity_data_builder_, /*mode=*/ModeFlag_ANNOTATION); grammar::Matcher matcher(&unilib_, model_->rules(), locale_rules, &callback_handler); lexer_.Process(text, tokenizer_.Tokenize(text), /*annotations=*/nullptr, &matcher); // Populate results. return callback_handler.GetAnnotations(UnicodeCodepointOffsets(text), result); } bool GrammarAnnotator::SuggestSelection(const std::vector& locales, const UnicodeText& text, const CodepointSpan& selection, AnnotatedSpan* result) const { if (model_ == nullptr || model_->rules() == nullptr || selection == CodepointSpan{kInvalidIndex, kInvalidIndex}) { // Nothing to do. return false; } // Select locale matching rules. std::vector locale_rules = SelectLocaleMatchingShards(model_->rules(), rules_locales_, locales); if (locale_rules.empty()) { // Nothing to do. return true; } // Run the grammar. GrammarAnnotatorCallbackDelegate callback_handler( &unilib_, model_, entity_data_builder_, /*mode=*/ModeFlag_SELECTION); grammar::Matcher matcher(&unilib_, model_->rules(), locale_rules, &callback_handler); lexer_.Process(text, tokenizer_.Tokenize(text), /*annotations=*/nullptr, &matcher); // Populate the result. return callback_handler.GetTextSelection(UnicodeCodepointOffsets(text), selection, result); } bool GrammarAnnotator::ClassifyText( const std::vector& locales, const UnicodeText& text, const CodepointSpan& selection, ClassificationResult* classification_result) const { if (model_ == nullptr || model_->rules() == nullptr || selection == CodepointSpan{kInvalidIndex, kInvalidIndex}) { // Nothing to do. return false; } // Select locale matching rules. std::vector locale_rules = SelectLocaleMatchingShards(model_->rules(), rules_locales_, locales); if (locale_rules.empty()) { // Nothing to do. return false; } // Run the grammar. GrammarAnnotatorCallbackDelegate callback_handler( &unilib_, model_, entity_data_builder_, /*mode=*/ModeFlag_CLASSIFICATION); grammar::Matcher matcher(&unilib_, model_->rules(), locale_rules, &callback_handler); const std::vector tokens = tokenizer_.Tokenize(text); if (model_->context_left_num_tokens() == -1 && model_->context_right_num_tokens() == -1) { // Use all tokens. lexer_.Process(text, tokens, /*annotations=*/{}, &matcher); } else { TokenSpan context_span = CodepointSpanToTokenSpan( tokens, selection, /*snap_boundaries_to_containing_tokens=*/true); std::vector::const_iterator begin = tokens.begin(); std::vector::const_iterator end = tokens.begin(); if (model_->context_left_num_tokens() != -1) { std::advance(begin, std::max(0, context_span.first - model_->context_left_num_tokens())); } if (model_->context_right_num_tokens() == -1) { end = tokens.end(); } else { std::advance(end, std::min(static_cast(tokens.size()), context_span.second + model_->context_right_num_tokens())); } lexer_.Process(text, begin, end, /*annotations=*/nullptr, &matcher); } // Populate result. return callback_handler.GetClassification(UnicodeCodepointOffsets(text), selection, classification_result); } } // namespace libtextclassifier3