/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "utils/grammar/lexer.h"

#include <algorithm>
#include <memory>
#include <tuple>
#include <unordered_map>
#include <vector>

#include "annotator/types.h"
#include "utils/zlib/zlib.h"
#include "utils/zlib/zlib_regex.h"

namespace libtextclassifier3::grammar {
namespace {

// Returns whether the matcher's arena is still within the allowed memory
// budget; once exceeded, no further matches are allocated or emitted.
inline bool CheckMemoryUsage(const Matcher* matcher) {
  // The maximum memory usage for matching.
  constexpr int kMaxMemoryUsage = 1 << 20;
  return matcher->ArenaSize() <= kMaxMemoryUsage;
}

// Allocates (but does not emit) a match for `nonterm`, unless the
// nonterminal is unused by the grammar (kUnassignedNonterm) or the memory
// budget is exhausted. Returns nullptr if no match was allocated.
Match* CheckedAddMatch(const Nonterm nonterm,
                       const CodepointSpan codepoint_span,
                       const int match_offset, const int16 type,
                       Matcher* matcher) {
  if (nonterm == kUnassignedNonterm || !CheckMemoryUsage(matcher)) {
    return nullptr;
  }
  return matcher->AllocateAndInitMatch<Match>(nonterm, codepoint_span,
                                              match_offset, type);
}

// Allocates and immediately emits a match for `nonterm`, subject to the same
// nonterminal-in-use and memory-budget checks as `CheckedAddMatch`.
void CheckedEmit(const Nonterm nonterm, const CodepointSpan codepoint_span,
                 const int match_offset, int16 type, Matcher* matcher) {
  if (nonterm != kUnassignedNonterm && CheckMemoryUsage(matcher)) {
    matcher->AddMatch(matcher->AllocateAndInitMatch<Match>(
        nonterm, codepoint_span, match_offset, type));
  }
}

// If `start` is a recorded token start, maps it to the start of the
// whitespace run preceding that token; otherwise returns `start` unchanged.
int MapCodepointToTokenPaddingIfPresent(
    const std::unordered_map<CodepointIndex, CodepointIndex>& token_alignment,
    const int start) {
  const auto it = token_alignment.find(start);
  if (it != token_alignment.end()) {
    return it->second;
  }
  return start;
}

}  // namespace

Lexer::Lexer(const UniLib* unilib, const RulesSet* rules)
    : unilib_(*unilib),
      rules_(rules),
      regex_annotators_(BuildRegexAnnotator(unilib_, rules)) {}

// Builds the regex annotators declared in the rules set, uncompressing the
// patterns if necessary. Compilation may be deferred if the rules request
// lazy regex compilation.
std::vector<Lexer::RegexAnnotator> Lexer::BuildRegexAnnotator(
    const UniLib& unilib, const RulesSet* rules) const {
  std::vector<Lexer::RegexAnnotator> result;
  if (rules->regex_annotator() != nullptr) {
    std::unique_ptr<ZlibDecompressor> decompressor =
        ZlibDecompressor::Instance();
    result.reserve(rules->regex_annotator()->size());
    for (const RulesSet_::RegexAnnotator* regex_annotator :
         *rules->regex_annotator()) {
      result.push_back(
          {UncompressMakeRegexPattern(unilib_, regex_annotator->pattern(),
                                      regex_annotator->compressed_pattern(),
                                      rules->lazy_regex_compilation(),
                                      decompressor.get()),
           regex_annotator->nonterminal()});
    }
  }
  return result;
}

// Emits a single symbol to the matcher. Pre-defined matches are forwarded
// as-is; token symbols additionally feed the special nonterminals
// (<digits>, <n_digits>, <uppercase_token>, <token>) when the grammar uses
// them, and are always fed as plain terminals.
void Lexer::Emit(const Symbol& symbol, const RulesSet_::Nonterminals* nonterms,
                 Matcher* matcher) const {
  switch (symbol.type) {
    case Symbol::Type::TYPE_MATCH: {
      // Just emit the match.
      matcher->AddMatch(symbol.match);
      return;
    }
    case Symbol::Type::TYPE_DIGITS: {
      // Emit <digits> if used by the rules.
      CheckedEmit(nonterms->digits_nt(), symbol.codepoint_span,
                  symbol.match_offset, Match::kDigitsType, matcher);

      // Emit <n_digits> if used by the rules.
      if (nonterms->n_digits_nt() != nullptr) {
        const int num_digits =
            symbol.codepoint_span.second - symbol.codepoint_span.first;
        if (num_digits <= nonterms->n_digits_nt()->size()) {
          CheckedEmit(nonterms->n_digits_nt()->Get(num_digits - 1),
                      symbol.codepoint_span, symbol.match_offset,
                      Match::kDigitsType, matcher);
        }
      }
      break;
    }
    case Symbol::Type::TYPE_TERM: {
      // Emit <uppercase_token> if used by the rules.
      if (nonterms->uppercase_token_nt() != 0 &&
          unilib_.IsUpperText(
              UTF8ToUnicodeText(symbol.lexeme, /*do_copy=*/false))) {
        CheckedEmit(nonterms->uppercase_token_nt(), symbol.codepoint_span,
                    symbol.match_offset, Match::kTokenType, matcher);
      }
      break;
    }
    default:
      break;
  }

  // Emit the token as terminal.
  if (CheckMemoryUsage(matcher)) {
    matcher->AddTerminal(symbol.codepoint_span, symbol.match_offset,
                         symbol.lexeme);
  }

  // Emit <token> if used by rules.
  CheckedEmit(nonterms->token_nt(), symbol.codepoint_span, symbol.match_offset,
              Match::kTokenType, matcher);
}

// Classifies a single codepoint into punctuation, digits or generic term.
Lexer::Symbol::Type Lexer::GetSymbolType(
    const UnicodeText::const_iterator& it) const {
  if (unilib_.IsPunctuation(*it)) {
    return Symbol::Type::TYPE_PUNCTUATION;
  } else if (unilib_.IsDigit(*it)) {
    return Symbol::Type::TYPE_DIGITS;
  } else {
    return Symbol::Type::TYPE_TERM;
  }
}

// Splits a token into sub-token symbols: maximal runs of same-typed
// codepoints, except punctuation which is emitted one codepoint at a time.
// `prev_token_end` is propagated as the match offset of the first sub-token
// so whitespace preceding the token is accounted for.
void Lexer::ProcessToken(const StringPiece value, const int prev_token_end,
                         const CodepointSpan codepoint_span,
                         std::vector<Lexer::Symbol>* symbols) const {
  // Possibly split token.
  UnicodeText token_unicode = UTF8ToUnicodeText(value.data(), value.size(),
                                                /*do_copy=*/false);
  int last_end = prev_token_end;
  auto token_end = token_unicode.end();
  auto it = token_unicode.begin();
  Symbol::Type type = GetSymbolType(it);
  CodepointIndex sub_token_start = codepoint_span.first;
  while (it != token_end) {
    auto next = std::next(it);
    int num_codepoints = 1;
    Symbol::Type next_type;
    while (next != token_end) {
      next_type = GetSymbolType(next);
      // Punctuation is never merged; other runs end when the type changes.
      if (type == Symbol::Type::TYPE_PUNCTUATION || next_type != type) {
        break;
      }
      ++next;
      ++num_codepoints;
    }
    symbols->push_back(Symbol{
        type, CodepointSpan{sub_token_start, sub_token_start + num_codepoints},
        /*match_offset=*/last_end,
        /*lexeme=*/
        StringPiece(it.utf8_data(), next.utf8_data() - it.utf8_data())});
    last_end = sub_token_start + num_codepoints;
    it = next;
    type = next_type;
    sub_token_start = last_end;
  }
}

// Convenience overload that processes the full token list.
void Lexer::Process(const UnicodeText& text, const std::vector<Token>& tokens,
                    const std::vector<AnnotatedSpan>* annotations,
                    Matcher* matcher) const {
  return Process(text, tokens.begin(), tokens.end(), annotations, matcher);
}

// Runs the lexer over the token range [begin, end): collects all symbols
// (break markers, token pieces, annotation and regex matches), sorts them by
// end position and feeds them to the matcher.
void Lexer::Process(const UnicodeText& text,
                    const std::vector<Token>::const_iterator& begin,
                    const std::vector<Token>::const_iterator& end,
                    const std::vector<AnnotatedSpan>* annotations,
                    Matcher* matcher) const {
  if (begin == end) {
    return;
  }

  const RulesSet_::Nonterminals* nonterminals = rules_->nonterminals();

  // Initialize processing of new text.
  CodepointIndex prev_token_end = 0;
  std::vector<Symbol> symbols;
  matcher->Reset();

  // The matcher expects the terminals and non-terminals it receives to be in
  // non-decreasing end-position order. The sorting below makes sure the
  // pre-defined matches adhere to that order.
  // Ideally, we would just have to emit a predefined match whenever we see
  // that the next token we feed would be ending later.
  // But as we implicitly ignore whitespace, we have to merge preceding
  // whitespace to the match start so that tokens and non-terminals fed appear
  // as next to each other without whitespace.
  // We keep track of real token starts and preceding whitespace in
  // `token_match_start`, so that we can extend a predefined match's start to
  // include the preceding whitespace.
  std::unordered_map<CodepointIndex, CodepointIndex> token_match_start;

  // Add start symbols.
  if (Match* match =
          CheckedAddMatch(nonterminals->start_nt(), CodepointSpan{0, 0},
                          /*match_offset=*/0, Match::kBreakType, matcher)) {
    symbols.push_back(Symbol(match));
  }
  if (Match* match =
          CheckedAddMatch(nonterminals->wordbreak_nt(), CodepointSpan{0, 0},
                          /*match_offset=*/0, Match::kBreakType, matcher)) {
    symbols.push_back(Symbol(match));
  }

  for (auto token_it = begin; token_it != end; token_it++) {
    const Token& token = *token_it;

    // Record match starts for token boundaries, so that we can snap
    // pre-defined matches to it.
    if (prev_token_end != token.start) {
      token_match_start[token.start] = prev_token_end;
    }

    ProcessToken(token.value,
                 /*prev_token_end=*/prev_token_end,
                 CodepointSpan{token.start, token.end}, &symbols);
    prev_token_end = token.end;

    // Add word break symbol if used by the grammar.
    if (Match* match = CheckedAddMatch(
            nonterminals->wordbreak_nt(), CodepointSpan{token.end, token.end},
            /*match_offset=*/token.end, Match::kBreakType, matcher)) {
      symbols.push_back(Symbol(match));
    }
  }

  // Add end symbol if used by the grammar.
  if (Match* match = CheckedAddMatch(
          nonterminals->end_nt(), CodepointSpan{prev_token_end, prev_token_end},
          /*match_offset=*/prev_token_end, Match::kBreakType, matcher)) {
    symbols.push_back(Symbol(match));
  }

  // Add matches based on annotations.
  auto annotation_nonterminals = nonterminals->annotation_nt();
  if (annotation_nonterminals != nullptr && annotations != nullptr) {
    for (const AnnotatedSpan& annotated_span : *annotations) {
      // Guard against spans without any classification result; calling
      // front() on an empty vector would be undefined behavior.
      if (annotated_span.classification.empty()) {
        continue;
      }
      const ClassificationResult& classification =
          annotated_span.classification.front();
      if (auto entry = annotation_nonterminals->LookupByKey(
              classification.collection.c_str())) {
        AnnotationMatch* match = matcher->AllocateAndInitMatch<AnnotationMatch>(
            entry->value(), annotated_span.span,
            /*match_offset=*/
            MapCodepointToTokenPaddingIfPresent(token_match_start,
                                                annotated_span.span.first),
            Match::kAnnotationMatch);
        match->annotation = &classification;
        symbols.push_back(Symbol(match));
      }
    }
  }

  // Add regex annotator matches for the range covered by the tokens.
  for (const RegexAnnotator& regex_annotator : regex_annotators_) {
    std::unique_ptr<UniLib::RegexMatcher> regex_matcher =
        regex_annotator.pattern->Matcher(UnicodeText::Substring(
            text, begin->start, prev_token_end, /*do_copy=*/false));
    int status = UniLib::RegexMatcher::kNoError;
    while (regex_matcher->Find(&status) &&
           status == UniLib::RegexMatcher::kNoError) {
      // Matcher offsets are relative to the substring; shift back into the
      // coordinate system of the full text.
      const CodepointSpan span = {
          regex_matcher->Start(0, &status) + begin->start,
          regex_matcher->End(0, &status) + begin->start};
      if (Match* match =
              CheckedAddMatch(regex_annotator.nonterm, span,
                              /*match_offset=*/
                              MapCodepointToTokenPaddingIfPresent(
                                  token_match_start, span.first),
                              Match::kUnknownType, matcher)) {
        symbols.push_back(Symbol(match));
      }
    }
  }

  std::sort(symbols.begin(), symbols.end(),
            [](const Symbol& a, const Symbol& b) {
              // Sort by increasing (end, start) position to guarantee the
              // matcher requirement that the tokens are fed in non-decreasing
              // end position order.
              return std::tie(a.codepoint_span.second, a.codepoint_span.first) <
                     std::tie(b.codepoint_span.second, b.codepoint_span.first);
            });

  // Emit symbols to matcher.
  for (const Symbol& symbol : symbols) {
    Emit(symbol, nonterminals, matcher);
  }

  // Finish the matching.
  matcher->Finish();
}

}  // namespace libtextclassifier3::grammar