/* * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "utils/lua-utils.h" #include #include "utils/flatbuffers.h" #include "gmock/gmock.h" #include "gtest/gtest.h" namespace libtextclassifier3 { namespace { using testing::ElementsAre; using testing::Eq; using testing::FloatEq; std::string TestFlatbufferSchema() { // Creates a test schema for flatbuffer passing tests. // Cannot use the object oriented API here as that is not available for the // reflection schema. flatbuffers::FlatBufferBuilder schema_builder; std::vector> fields = { reflection::CreateField( schema_builder, /*name=*/schema_builder.CreateString("float_field"), /*type=*/ reflection::CreateType(schema_builder, /*base_type=*/reflection::Float), /*id=*/0, /*offset=*/4), reflection::CreateField( schema_builder, /*name=*/schema_builder.CreateString("nested_field"), /*type=*/ reflection::CreateType(schema_builder, /*base_type=*/reflection::Obj, /*element=*/reflection::None, /*index=*/0 /* self */), /*id=*/1, /*offset=*/6), reflection::CreateField( schema_builder, /*name=*/schema_builder.CreateString("repeated_nested_field"), /*type=*/ reflection::CreateType(schema_builder, /*base_type=*/reflection::Vector, /*element=*/reflection::Obj, /*index=*/0 /* self */), /*id=*/2, /*offset=*/8), reflection::CreateField( schema_builder, /*name=*/schema_builder.CreateString("repeated_string_field"), /*type=*/ reflection::CreateType(schema_builder, /*base_type=*/reflection::Vector, /*element=*/reflection::String), /*id=*/3, /*offset=*/10), reflection::CreateField( schema_builder, /*name=*/schema_builder.CreateString("string_field"), /*type=*/ reflection::CreateType(schema_builder, /*base_type=*/reflection::String), /*id=*/4, /*offset=*/12)}; std::vector> enums; std::vector> objects = { reflection::CreateObject( schema_builder, /*name=*/schema_builder.CreateString("TestData"), /*fields=*/ schema_builder.CreateVectorOfSortedTables(&fields))}; schema_builder.Finish(reflection::CreateSchema( schema_builder, schema_builder.CreateVectorOfSortedTables(&objects), schema_builder.CreateVectorOfSortedTables(&enums), /*(unused) file_ident=*/0, /*(unused) file_ext=*/0, /*root_table*/ objects[0])); return std::string( reinterpret_cast(schema_builder.GetBufferPointer()), schema_builder.GetSize()); } class LuaUtilsTest : public testing::Test, protected LuaEnvironment { protected: LuaUtilsTest() : serialized_flatbuffer_schema_(TestFlatbufferSchema()), schema_(flatbuffers::GetRoot( serialized_flatbuffer_schema_.data())), flatbuffer_builder_(schema_) { EXPECT_THAT(RunProtected([this] { LoadDefaultLibraries(); return LUA_OK; }), Eq(LUA_OK)); } void RunScript(StringPiece script) { EXPECT_THAT(luaL_loadbuffer(state_, script.data(), script.size(), /*name=*/nullptr), Eq(LUA_OK)); EXPECT_THAT( lua_pcall(state_, /*nargs=*/0, /*num_results=*/1, /*errfunc=*/0), Eq(LUA_OK)); } const std::string serialized_flatbuffer_schema_; const reflection::Schema* schema_; ReflectiveFlatbufferBuilder flatbuffer_builder_; }; TEST_F(LuaUtilsTest, HandlesVectors) { { PushVector(std::vector{1, 2, 3, 4, 5}); EXPECT_THAT(ReadVector(), ElementsAre(1, 2, 3, 4, 5)); } { PushVector(std::vector{"hello", "there"}); EXPECT_THAT(ReadVector(), ElementsAre("hello", "there")); } { PushVector(std::vector{true, true, false}); EXPECT_THAT(ReadVector(), ElementsAre(true, true, false)); } } TEST_F(LuaUtilsTest, HandlesVectorIterators) { { const std::vector elements = {1, 2, 3, 4, 5}; PushVectorIterator(&elements); EXPECT_THAT(ReadVector(), ElementsAre(1, 2, 3, 4, 5)); } { const std::vector elements = {"hello", "there"}; PushVectorIterator(&elements); EXPECT_THAT(ReadVector(), ElementsAre("hello", "there")); } { const std::vector elements = {true, true, false}; PushVectorIterator(&elements); EXPECT_THAT(ReadVector(), ElementsAre(true, true, false)); } } TEST_F(LuaUtilsTest, ReadsFlatbufferResults) { // Setup. RunScript(R"lua( return { float_field = 42.1, string_field = "hello there", -- Nested field. nested_field = { float_field = 64, string_field = "hello nested", }, -- Repeated fields. repeated_string_field = { "a", "bold", "one" }, repeated_nested_field = { { string_field = "a" }, { string_field = "b" }, { repeated_string_field = { "nested", "nested2" } }, }, } )lua"); // Read the flatbuffer. std::unique_ptr buffer = flatbuffer_builder_.NewRoot(); ReadFlatbuffer(/*index=*/-1, buffer.get()); const std::string serialized_buffer = buffer->Serialize(); // Check fields. As we do not have flatbuffer compiled generated code for the // ad hoc generated test schema, we have to read by manually using field // offsets. const flatbuffers::Table* flatbuffer_data = flatbuffers::GetRoot(serialized_buffer.data()); EXPECT_THAT(flatbuffer_data->GetField(/*field=*/4, /*defaultval=*/0), FloatEq(42.1)); EXPECT_THAT( flatbuffer_data->GetPointer(/*field=*/12) ->str(), "hello there"); // Read the nested field. const flatbuffers::Table* nested_field = flatbuffer_data->GetPointer(/*field=*/6); EXPECT_THAT(nested_field->GetField(/*field=*/4, /*defaultval=*/0), FloatEq(64)); EXPECT_THAT( nested_field->GetPointer(/*field=*/12)->str(), "hello nested"); // Read the repeated string field. auto repeated_strings = flatbuffer_data->GetPointer< flatbuffers::Vector>*>( /*field=*/10); EXPECT_THAT(repeated_strings->size(), Eq(3)); EXPECT_THAT(repeated_strings->GetAsString(0)->str(), Eq("a")); EXPECT_THAT(repeated_strings->GetAsString(1)->str(), Eq("bold")); EXPECT_THAT(repeated_strings->GetAsString(2)->str(), Eq("one")); // Read the repeated nested field. auto repeated_nested_fields = flatbuffer_data->GetPointer< flatbuffers::Vector>*>( /*field=*/8); EXPECT_THAT(repeated_nested_fields->size(), Eq(3)); EXPECT_THAT(repeated_nested_fields->Get(0) ->GetPointer(/*field=*/12) ->str(), "a"); EXPECT_THAT(repeated_nested_fields->Get(1) ->GetPointer(/*field=*/12) ->str(), "b"); } TEST_F(LuaUtilsTest, HandlesSimpleFlatbufferFields) { // Create test flatbuffer. std::unique_ptr buffer = flatbuffer_builder_.NewRoot(); buffer->Set("float_field", 42.f); const std::string serialized_buffer = buffer->Serialize(); PushFlatbuffer(schema_, flatbuffers::GetRoot( serialized_buffer.data())); lua_setglobal(state_, "arg"); // Setup. RunScript(R"lua( return arg.float_field )lua"); EXPECT_THAT(Read(), FloatEq(42)); } TEST_F(LuaUtilsTest, HandlesRepeatedFlatbufferFields) { // Create test flatbuffer. std::unique_ptr buffer = flatbuffer_builder_.NewRoot(); RepeatedField* repeated_field = buffer->Repeated("repeated_string_field"); repeated_field->Add("this"); repeated_field->Add("is"); repeated_field->Add("a"); repeated_field->Add("test"); const std::string serialized_buffer = buffer->Serialize(); PushFlatbuffer(schema_, flatbuffers::GetRoot( serialized_buffer.data())); lua_setglobal(state_, "arg"); // Return flatbuffer repeated field as vector. RunScript(R"lua( return arg.repeated_string_field )lua"); EXPECT_THAT(ReadVector(), ElementsAre("this", "is", "a", "test")); } TEST_F(LuaUtilsTest, HandlesRepeatedNestedFlatbufferFields) { // Create test flatbuffer. std::unique_ptr buffer = flatbuffer_builder_.NewRoot(); RepeatedField* repeated_field = buffer->Repeated("repeated_nested_field"); repeated_field->Add()->Set("string_field", "hello"); repeated_field->Add()->Set("string_field", "my"); ReflectiveFlatbuffer* nested = repeated_field->Add(); nested->Set("string_field", "old"); RepeatedField* nested_repeated = nested->Repeated("repeated_string_field"); nested_repeated->Add("friend"); nested_repeated->Add("how"); nested_repeated->Add("are"); repeated_field->Add()->Set("string_field", "you?"); const std::string serialized_buffer = buffer->Serialize(); PushFlatbuffer(schema_, flatbuffers::GetRoot( serialized_buffer.data())); lua_setglobal(state_, "arg"); RunScript(R"lua( result = {} for _, nested in pairs(arg.repeated_nested_field) do result[#result + 1] = nested.string_field for _, nested_string in pairs(nested.repeated_string_field) do result[#result + 1] = nested_string end end return result )lua"); EXPECT_THAT( ReadVector(), ElementsAre("hello", "my", "old", "friend", "how", "are", "you?")); } TEST_F(LuaUtilsTest, CorrectlyReadsTwoFlatbuffersSimultaneously) { // The first flatbuffer. std::unique_ptr buffer = flatbuffer_builder_.NewRoot(); buffer->Set("string_field", "first"); const std::string serialized_buffer = buffer->Serialize(); PushFlatbuffer(schema_, flatbuffers::GetRoot( serialized_buffer.data())); lua_setglobal(state_, "arg"); // The second flatbuffer. std::unique_ptr buffer2 = flatbuffer_builder_.NewRoot(); buffer2->Set("string_field", "second"); const std::string serialized_buffer2 = buffer2->Serialize(); PushFlatbuffer(schema_, flatbuffers::GetRoot( serialized_buffer2.data())); lua_setglobal(state_, "arg2"); RunScript(R"lua( return {arg.string_field, arg2.string_field} )lua"); EXPECT_THAT(ReadVector(), ElementsAre("first", "second")); } } // namespace } // namespace libtextclassifier3