OLD | NEW |
| (Empty) |
1 // Copyright 2017 The Chromium Authors. All rights reserved. | |
2 // Use of this source code is governed by a BSD-style license that can be | |
3 // found in the LICENSE file. | |
4 | |
5 #include "components/translate/core/browser/ranker_model_loader.h" | |
6 | |
7 #include <deque> | |
8 #include <initializer_list> | |
9 #include <memory> | |
10 #include <vector> | |
11 | |
12 #include "base/files/file_util.h" | |
13 #include "base/files/scoped_temp_dir.h" | |
14 #include "base/memory/ptr_util.h" | |
15 #include "base/memory/ref_counted.h" | |
16 #include "base/message_loop/message_loop.h" | |
17 #include "base/run_loop.h" | |
18 #include "base/strings/stringprintf.h" | |
19 #include "base/task_scheduler/post_task.h" | |
20 #include "base/task_scheduler/task_scheduler.h" | |
21 #include "base/test/scoped_feature_list.h" | |
22 #include "base/test/scoped_task_scheduler.h" | |
23 #include "base/test/test_simple_task_runner.h" | |
24 #include "base/threading/thread_task_runner_handle.h" | |
25 #include "components/translate/core/browser/proto/ranker_model.pb.h" | |
26 #include "components/translate/core/browser/proto/translate_ranker_model.pb.h" | |
27 #include "components/translate/core/browser/ranker_model.h" | |
28 #include "components/translate/core/browser/translate_download_manager.h" | |
29 #include "net/url_request/test_url_fetcher_factory.h" | |
30 #include "net/url_request/url_request_test_util.h" | |
31 #include "testing/gtest/include/gtest/gtest.h" | |
32 | |
33 namespace { | |
34 | |
35 using base::TaskScheduler; | |
36 using chrome_intelligence::RankerModel; | |
37 using translate::RankerModelLoader; | |
38 using translate::RankerModelStatus; | |
39 using translate::TranslateDownloadManager; | |
40 | |
41 const char kInvalidModelData[] = "not a valid model"; | |
42 const int kInvalidModelSize = sizeof(kInvalidModelData) - 1; | |
43 | |
44 class RankerModelLoaderTest : public ::testing::Test { | |
45 protected: | |
46 RankerModelLoaderTest(); | |
47 | |
48 void SetUp() override; | |
49 | |
50 void TearDown() override; | |
51 | |
52 // Returns a copy of |model|. | |
53 static std::unique_ptr<RankerModel> Clone(const RankerModel& model); | |
54 | |
55 // Returns true if |m1| and |m2| are identical. | |
56 static bool IsEqual(const RankerModel& m1, const RankerModel& m2); | |
57 | |
58 // Returns true if |m1| and |m2| are identical modulo metadata. | |
59 static bool IsEquivalent(const RankerModel& m1, const RankerModel& m2); | |
60 | |
61 // Helper method to drive the loader for |model_path| and |model_url|. | |
62 bool DoLoaderTest(const base::FilePath& model_path, const GURL& model_url); | |
63 | |
64 // Initialize the "remote" model data used for testing. | |
65 void InitRemoteModels(); | |
66 | |
67 // Initialize the "local" model data used for testing. | |
68 void InitLocalModels(); | |
69 | |
70 // Helper method used by InitRemoteModels() and InitLocalModels(). | |
71 void InitModel(const GURL& model_url, | |
72 const base::Time& last_modified, | |
73 const base::TimeDelta& cache_duration, | |
74 RankerModel* model); | |
75 | |
76 // Save |model| to |model_path|. Used by InitRemoteModels() and | |
77 // InitLocalModels() | |
78 void SaveModel(const RankerModel& model, const base::FilePath& model_path); | |
79 | |
80 // Implements RankerModelLoader's ValidateModelCallback interface. | |
81 RankerModelStatus ValidateModel(const RankerModel& model); | |
82 | |
83 // Implements RankerModelLoader's OnModelAvailableCallback interface. | |
84 void OnModelAvailable(std::unique_ptr<RankerModel> model); | |
85 | |
86 // Sets up the task scheduling/task-runner environment for each test. | |
87 base::test::ScopedTaskScheduler scoped_task_scheduler_; | |
88 | |
89 // Override the default URL fetcher to return custom responses for tests. | |
90 net::FakeURLFetcherFactory url_fetcher_factory_; | |
91 | |
92 // Temporary directory for model files. | |
93 base::ScopedTempDir scoped_temp_dir_; | |
94 | |
95 // Cache and reset the application locale for each test. | |
96 std::string locale_; | |
97 | |
98 // Used to initialize the translate download manager. | |
99 scoped_refptr<net::TestURLRequestContextGetter> request_context_; | |
100 | |
101 // A queue of responses to return from Validate(). If empty, validate will | |
102 // return 'OK'. | |
103 std::deque<RankerModelStatus> validate_model_response_; | |
104 | |
105 // A cached to remember the model validation calls. | |
106 std::vector<std::unique_ptr<RankerModel>> validated_models_; | |
107 | |
108 // A cache to remember the OnModelAvailable calls. | |
109 std::vector<std::unique_ptr<RankerModel>> available_models_; | |
110 | |
111 // Cached model file paths. | |
112 base::FilePath local_model_path_; | |
113 base::FilePath expired_model_path_; | |
114 base::FilePath invalid_model_path_; | |
115 | |
116 // Model URLS. | |
117 GURL remote_model_url_; | |
118 GURL invalid_model_url_; | |
119 GURL failed_model_url_; | |
120 | |
121 // Model Data. | |
122 RankerModel remote_model_; | |
123 RankerModel local_model_; | |
124 RankerModel expired_model_; | |
125 | |
126 private: | |
127 DISALLOW_COPY_AND_ASSIGN(RankerModelLoaderTest); | |
128 }; | |
129 | |
130 RankerModelLoaderTest::RankerModelLoaderTest() | |
131 : url_fetcher_factory_(nullptr) {} | |
132 | |
133 void RankerModelLoaderTest::SetUp() { | |
134 // Setup the translate download manager. | |
135 locale_ = TranslateDownloadManager::GetInstance()->application_locale(); | |
136 request_context_ = | |
137 new net::TestURLRequestContextGetter(base::ThreadTaskRunnerHandle::Get()); | |
138 TranslateDownloadManager::GetInstance()->set_application_locale("fr-CA"); | |
139 TranslateDownloadManager::GetInstance()->set_request_context( | |
140 request_context_.get()); | |
141 | |
142 ASSERT_TRUE(scoped_temp_dir_.CreateUniqueTempDir()); | |
143 const auto& temp_dir_path = scoped_temp_dir_.GetPath(); | |
144 | |
145 // Setup the model file paths. | |
146 local_model_path_ = temp_dir_path.AppendASCII("local_model.bin"); | |
147 expired_model_path_ = temp_dir_path.AppendASCII("expired_model.bin"); | |
148 invalid_model_path_ = temp_dir_path.AppendASCII("invalid_model.bin"); | |
149 | |
150 // Setup the model URLs. | |
151 remote_model_url_ = GURL("https://some.url.net/good.model.bin"); | |
152 invalid_model_url_ = GURL("https://some.url.net/bad.model.bin"); | |
153 failed_model_url_ = GURL("https://some.url.net/fail"); | |
154 | |
155 // Initialize the model data. | |
156 ASSERT_NO_FATAL_FAILURE(InitRemoteModels()); | |
157 ASSERT_NO_FATAL_FAILURE(InitLocalModels()); | |
158 } | |
159 | |
160 void RankerModelLoaderTest::TearDown() { | |
161 base::RunLoop().RunUntilIdle(); | |
162 TranslateDownloadManager::GetInstance()->set_application_locale(locale_); | |
163 TranslateDownloadManager::GetInstance()->set_request_context(nullptr); | |
164 } | |
165 | |
166 // static | |
167 std::unique_ptr<RankerModel> RankerModelLoaderTest::Clone( | |
168 const RankerModel& model) { | |
169 auto copy = base::MakeUnique<RankerModel>(); | |
170 *copy->mutable_proto() = model.proto(); | |
171 return copy; | |
172 } | |
173 | |
174 // static | |
175 bool RankerModelLoaderTest::IsEqual(const RankerModel& m1, | |
176 const RankerModel& m2) { | |
177 return m1.SerializeAsString() == m2.SerializeAsString(); | |
178 } | |
179 | |
180 // static | |
181 bool RankerModelLoaderTest::IsEquivalent(const RankerModel& m1, | |
182 const RankerModel& m2) { | |
183 auto copy_m1 = Clone(m1); | |
184 copy_m1->mutable_proto()->mutable_metadata()->Clear(); | |
185 | |
186 auto copy_m2 = Clone(m2); | |
187 copy_m2->mutable_proto()->mutable_metadata()->Clear(); | |
188 | |
189 return IsEqual(*copy_m1, *copy_m2); | |
190 } | |
191 | |
192 bool RankerModelLoaderTest::DoLoaderTest(const base::FilePath& model_path, | |
193 const GURL& model_url) { | |
194 auto loader = base::MakeUnique<RankerModelLoader>( | |
195 base::Bind(&RankerModelLoaderTest::ValidateModel, base::Unretained(this)), | |
196 base::Bind(&RankerModelLoaderTest::OnModelAvailable, | |
197 base::Unretained(this)), | |
198 model_path, model_url, "RankerModelLoaderTest"); | |
199 loader->NotifyOfRankerActivity(); | |
200 base::RunLoop().RunUntilIdle(); | |
201 | |
202 return true; | |
203 } | |
204 | |
205 void RankerModelLoaderTest::InitRemoteModels() { | |
206 InitModel(remote_model_url_, base::Time(), base::TimeDelta(), &remote_model_); | |
207 url_fetcher_factory_.SetFakeResponse( | |
208 remote_model_url_, remote_model_.SerializeAsString(), net::HTTP_OK, | |
209 net::URLRequestStatus::SUCCESS); | |
210 url_fetcher_factory_.SetFakeResponse(invalid_model_url_, kInvalidModelData, | |
211 net::HTTP_OK, | |
212 net::URLRequestStatus::SUCCESS); | |
213 url_fetcher_factory_.SetFakeResponse(failed_model_url_, "", | |
214 net::HTTP_INTERNAL_SERVER_ERROR, | |
215 net::URLRequestStatus::FAILED); | |
216 } | |
217 | |
218 void RankerModelLoaderTest::InitLocalModels() { | |
219 InitModel(remote_model_url_, base::Time::Now(), base::TimeDelta::FromDays(30), | |
220 &local_model_); | |
221 InitModel(remote_model_url_, | |
222 base::Time::Now() - base::TimeDelta::FromDays(60), | |
223 base::TimeDelta::FromDays(30), &expired_model_); | |
224 SaveModel(local_model_, local_model_path_); | |
225 SaveModel(expired_model_, expired_model_path_); | |
226 ASSERT_EQ(base::WriteFile(invalid_model_path_, kInvalidModelData, | |
227 kInvalidModelSize), | |
228 kInvalidModelSize); | |
229 } | |
230 | |
231 void RankerModelLoaderTest::InitModel(const GURL& model_url, | |
232 const base::Time& last_modified, | |
233 const base::TimeDelta& cache_duration, | |
234 RankerModel* model) { | |
235 ASSERT_TRUE(model != nullptr); | |
236 model->mutable_proto()->Clear(); | |
237 | |
238 auto* metadata = model->mutable_proto()->mutable_metadata(); | |
239 if (!model_url.is_empty()) | |
240 metadata->set_source(model_url.spec()); | |
241 if (!last_modified.is_null()) { | |
242 auto last_modified_sec = (last_modified - base::Time()).InSeconds(); | |
243 metadata->set_last_modified_sec(last_modified_sec); | |
244 } | |
245 if (!cache_duration.is_zero()) | |
246 metadata->set_cache_duration_sec(cache_duration.InSeconds()); | |
247 | |
248 auto* translate = model->mutable_proto()->mutable_translate(); | |
249 translate->set_version(1); | |
250 | |
251 auto* logit = translate->mutable_logistic_regression_model(); | |
252 logit->set_bias(0.1f); | |
253 logit->set_accept_ratio_weight(0.2f); | |
254 logit->set_decline_ratio_weight(0.3f); | |
255 logit->set_ignore_ratio_weight(0.4f); | |
256 } | |
257 | |
258 void RankerModelLoaderTest::SaveModel(const RankerModel& model, | |
259 const base::FilePath& model_path) { | |
260 std::string model_str = model.SerializeAsString(); | |
261 ASSERT_EQ(base::WriteFile(model_path, model_str.data(), model_str.size()), | |
262 static_cast<int>(model_str.size())); | |
263 } | |
264 | |
265 RankerModelStatus RankerModelLoaderTest::ValidateModel( | |
266 const RankerModel& model) { | |
267 validated_models_.push_back(Clone(model)); | |
268 RankerModelStatus response = RankerModelStatus::OK; | |
269 if (!validate_model_response_.empty()) { | |
270 response = validate_model_response_.front(); | |
271 validate_model_response_.pop_front(); | |
272 } | |
273 return response; | |
274 } | |
275 | |
276 void RankerModelLoaderTest::OnModelAvailable( | |
277 std::unique_ptr<RankerModel> model) { | |
278 available_models_.push_back(std::move(model)); | |
279 } | |
280 | |
281 } // namespace | |
282 | |
283 TEST_F(RankerModelLoaderTest, NoLocalOrRemoteModel) { | |
284 ASSERT_TRUE(DoLoaderTest(base::FilePath(), GURL())); | |
285 | |
286 EXPECT_EQ(0U, validated_models_.size()); | |
287 EXPECT_EQ(0U, available_models_.size()); | |
288 } | |
289 | |
290 TEST_F(RankerModelLoaderTest, BadLocalAndRemoteModel) { | |
291 ASSERT_TRUE(DoLoaderTest(invalid_model_path_, invalid_model_url_)); | |
292 | |
293 EXPECT_EQ(0U, validated_models_.size()); | |
294 EXPECT_EQ(0U, available_models_.size()); | |
295 } | |
296 | |
297 TEST_F(RankerModelLoaderTest, LoadFromFileOnly) { | |
298 EXPECT_TRUE(DoLoaderTest(local_model_path_, GURL())); | |
299 | |
300 ASSERT_EQ(1U, validated_models_.size()); | |
301 ASSERT_EQ(1U, available_models_.size()); | |
302 EXPECT_TRUE(IsEqual(*validated_models_[0], local_model_)); | |
303 EXPECT_TRUE(IsEqual(*available_models_[0], local_model_)); | |
304 } | |
305 | |
306 TEST_F(RankerModelLoaderTest, LoadFromFileSkipsDownload) { | |
307 ASSERT_TRUE(DoLoaderTest(local_model_path_, remote_model_url_)); | |
308 | |
309 ASSERT_EQ(1U, validated_models_.size()); | |
310 ASSERT_EQ(1U, available_models_.size()); | |
311 EXPECT_TRUE(IsEqual(*validated_models_[0], local_model_)); | |
312 EXPECT_TRUE(IsEqual(*available_models_[0], local_model_)); | |
313 } | |
314 | |
315 TEST_F(RankerModelLoaderTest, LoadFromFileAndBadUrl) { | |
316 ASSERT_TRUE(DoLoaderTest(local_model_path_, invalid_model_url_)); | |
317 ASSERT_EQ(1U, validated_models_.size()); | |
318 ASSERT_EQ(1U, available_models_.size()); | |
319 EXPECT_TRUE(IsEqual(*validated_models_[0], local_model_)); | |
320 EXPECT_TRUE(IsEqual(*available_models_[0], local_model_)); | |
321 } | |
322 | |
323 TEST_F(RankerModelLoaderTest, LoadFromURLOnly) { | |
324 ASSERT_TRUE(DoLoaderTest(base::FilePath(), remote_model_url_)); | |
325 ASSERT_EQ(1U, validated_models_.size()); | |
326 ASSERT_EQ(1U, available_models_.size()); | |
327 EXPECT_TRUE(IsEquivalent(*validated_models_[0], remote_model_)); | |
328 EXPECT_TRUE(IsEquivalent(*available_models_[0], remote_model_)); | |
329 } | |
330 | |
331 TEST_F(RankerModelLoaderTest, LoadFromExpiredFileTriggersDownload) { | |
332 ASSERT_TRUE(DoLoaderTest(expired_model_path_, remote_model_url_)); | |
333 ASSERT_EQ(2U, validated_models_.size()); | |
334 ASSERT_EQ(2U, available_models_.size()); | |
335 EXPECT_TRUE(IsEquivalent(*validated_models_[0], local_model_)); | |
336 EXPECT_TRUE(IsEquivalent(*available_models_[0], local_model_)); | |
337 EXPECT_TRUE(IsEquivalent(*validated_models_[1], remote_model_)); | |
338 EXPECT_TRUE(IsEquivalent(*available_models_[1], remote_model_)); | |
339 } | |
340 | |
341 TEST_F(RankerModelLoaderTest, LoadFromBadFileTriggersDownload) { | |
342 ASSERT_TRUE(DoLoaderTest(invalid_model_path_, remote_model_url_)); | |
343 ASSERT_EQ(1U, validated_models_.size()); | |
344 ASSERT_EQ(1U, available_models_.size()); | |
345 EXPECT_TRUE(IsEquivalent(*validated_models_[0], remote_model_)); | |
346 EXPECT_TRUE(IsEquivalent(*available_models_[0], remote_model_)); | |
347 } | |
348 | |
349 TEST_F(RankerModelLoaderTest, IncompatibleCachedFileTriggersDownload) { | |
350 validate_model_response_.push_back(RankerModelStatus::INCOMPATIBLE); | |
351 | |
352 ASSERT_TRUE(DoLoaderTest(local_model_path_, remote_model_url_)); | |
353 ASSERT_EQ(2U, validated_models_.size()); | |
354 ASSERT_EQ(1U, available_models_.size()); | |
355 EXPECT_TRUE(IsEquivalent(*validated_models_[0], local_model_)); | |
356 EXPECT_TRUE(IsEquivalent(*validated_models_[1], remote_model_)); | |
357 EXPECT_TRUE(IsEquivalent(*available_models_[0], remote_model_)); | |
358 } | |
359 | |
360 TEST_F(RankerModelLoaderTest, IncompatibleDownloadedFileKeepsExpired) { | |
361 validate_model_response_.push_back(RankerModelStatus::OK); | |
362 validate_model_response_.push_back(RankerModelStatus::INCOMPATIBLE); | |
363 | |
364 ASSERT_TRUE(DoLoaderTest(expired_model_path_, remote_model_url_)); | |
365 ASSERT_EQ(2U, validated_models_.size()); | |
366 ASSERT_EQ(1U, available_models_.size()); | |
367 EXPECT_TRUE(IsEquivalent(*validated_models_[0], local_model_)); | |
368 EXPECT_TRUE(IsEquivalent(*validated_models_[1], remote_model_)); | |
369 EXPECT_TRUE(IsEquivalent(*available_models_[0], local_model_)); | |
370 } | |
OLD | NEW |