OLD | NEW |
1 // Copyright 2017 The Chromium Authors. All rights reserved. | 1 // Copyright 2017 The Chromium Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
4 | 4 |
5 #include "components/translate/core/browser/ranker_model_loader.h" | 5 #include "components/machine_intelligence/ranker_model_loader.h" |
6 | 6 |
7 #include <utility> | 7 #include <utility> |
8 | 8 |
9 #include "base/bind.h" | 9 #include "base/bind.h" |
10 #include "base/bind_helpers.h" | 10 #include "base/bind_helpers.h" |
11 #include "base/command_line.h" | 11 #include "base/command_line.h" |
12 #include "base/files/file_util.h" | 12 #include "base/files/file_util.h" |
13 #include "base/files/important_file_writer.h" | 13 #include "base/files/important_file_writer.h" |
14 #include "base/macros.h" | 14 #include "base/macros.h" |
15 #include "base/memory/ptr_util.h" | 15 #include "base/memory/ptr_util.h" |
16 #include "base/metrics/histogram_macros.h" | 16 #include "base/metrics/histogram_macros.h" |
17 #include "base/profiler/scoped_tracker.h" | 17 #include "base/profiler/scoped_tracker.h" |
18 #include "base/sequenced_task_runner.h" | 18 #include "base/sequenced_task_runner.h" |
19 #include "base/strings/string_util.h" | 19 #include "base/strings/string_util.h" |
20 #include "base/task_runner_util.h" | 20 #include "base/task_runner_util.h" |
21 #include "base/task_scheduler/post_task.h" | 21 #include "base/task_scheduler/post_task.h" |
22 #include "base/threading/sequenced_task_runner_handle.h" | 22 #include "base/threading/sequenced_task_runner_handle.h" |
23 #include "components/translate/core/browser/proto/ranker_model.pb.h" | 23 #include "components/machine_intelligence/proto/ranker_model.pb.h" |
24 #include "components/translate/core/browser/ranker_model.h" | 24 #include "components/machine_intelligence/ranker_model.h" |
25 #include "components/translate/core/browser/translate_url_fetcher.h" | 25 #include "components/machine_intelligence/ranker_url_fetcher.h" |
26 | 26 |
27 namespace translate { | 27 namespace machine_intelligence { |
28 namespace { | 28 namespace { |
29 | 29 |
30 using chrome_intelligence::RankerModel; | |
31 using chrome_intelligence::RankerModelProto; | |
32 | |
33 constexpr int kUrlFetcherId = 2; | |
34 | |
35 // The minimum duration, in minutes, between download attempts. | 30 // The minimum duration, in minutes, between download attempts. |
36 constexpr int kMinRetryDelayMins = 3; | 31 constexpr int kMinRetryDelayMins = 3; |
37 | 32 |
38 // Suffixes for the various histograms produced by the backend. | 33 // Suffixes for the various histograms produced by the backend. |
39 const char kWriteTimerHistogram[] = ".Timer.WriteModel"; | 34 const char kWriteTimerHistogram[] = ".Timer.WriteModel"; |
40 const char kReadTimerHistogram[] = ".Timer.ReadModel"; | 35 const char kReadTimerHistogram[] = ".Timer.ReadModel"; |
41 const char kDownloadTimerHistogram[] = ".Timer.DownloadModel"; | 36 const char kDownloadTimerHistogram[] = ".Timer.DownloadModel"; |
42 const char kParsetimerHistogram[] = ".Timer.ParseModel"; | 37 const char kParsetimerHistogram[] = ".Timer.ParseModel"; |
43 const char kModelStatusHistogram[] = ".Model.Status"; | 38 const char kModelStatusHistogram[] = ".Model.Status"; |
44 | 39 |
(...skipping 44 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
89 << model_path.value() << "'."; | 84 << model_path.value() << "'."; |
90 MyScopedHistogramTimer timer(uma_prefix + kWriteTimerHistogram); | 85 MyScopedHistogramTimer timer(uma_prefix + kWriteTimerHistogram); |
91 base::ImportantFileWriter::WriteFileAtomically(model_path, model_data); | 86 base::ImportantFileWriter::WriteFileAtomically(model_path, model_data); |
92 } | 87 } |
93 | 88 |
94 } // namespace | 89 } // namespace |
95 | 90 |
96 RankerModelLoader::RankerModelLoader( | 91 RankerModelLoader::RankerModelLoader( |
97 ValidateModelCallback validate_model_cb, | 92 ValidateModelCallback validate_model_cb, |
98 OnModelAvailableCallback on_model_available_cb, | 93 OnModelAvailableCallback on_model_available_cb, |
| 94 net::URLRequestContextGetter* request_context_getter, |
99 base::FilePath model_path, | 95 base::FilePath model_path, |
100 GURL model_url, | 96 GURL model_url, |
101 std::string uma_prefix) | 97 std::string uma_prefix) |
102 : background_task_runner_(base::CreateSequencedTaskRunnerWithTraits( | 98 : background_task_runner_(base::CreateSequencedTaskRunnerWithTraits( |
103 {base::MayBlock(), base::TaskPriority::BACKGROUND, | 99 {base::MayBlock(), base::TaskPriority::BACKGROUND, |
104 base::TaskShutdownBehavior::SKIP_ON_SHUTDOWN})), | 100 base::TaskShutdownBehavior::SKIP_ON_SHUTDOWN})), |
105 validate_model_cb_(std::move(validate_model_cb)), | 101 validate_model_cb_(std::move(validate_model_cb)), |
106 on_model_available_cb_(std::move(on_model_available_cb)), | 102 on_model_available_cb_(std::move(on_model_available_cb)), |
| 103 request_context_getter_(request_context_getter), |
107 model_path_(std::move(model_path)), | 104 model_path_(std::move(model_path)), |
108 model_url_(std::move(model_url)), | 105 model_url_(std::move(model_url)), |
109 uma_prefix_(std::move(uma_prefix)), | 106 uma_prefix_(std::move(uma_prefix)), |
110 url_fetcher_(base::MakeUnique<TranslateURLFetcher>(kUrlFetcherId)), | 107 url_fetcher_(base::MakeUnique<RankerURLFetcher>()), |
111 weak_ptr_factory_(this) {} | 108 weak_ptr_factory_(this) {} |
112 | 109 |
113 RankerModelLoader::~RankerModelLoader() { | 110 RankerModelLoader::~RankerModelLoader() { |
114 DCHECK(sequence_checker_.CalledOnValidSequence()); | 111 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
115 } | 112 } |
116 | 113 |
117 void RankerModelLoader::NotifyOfRankerActivity() { | 114 void RankerModelLoader::NotifyOfRankerActivity() { |
118 DCHECK(sequence_checker_.CalledOnValidSequence()); | 115 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
119 switch (state_) { | 116 switch (state_) { |
120 case LoaderState::NOT_STARTED: | 117 case LoaderState::NOT_STARTED: |
121 if (!model_path_.empty()) { | 118 if (!model_path_.empty()) { |
122 StartLoadFromFile(); | 119 StartLoadFromFile(); |
123 break; | 120 break; |
124 } | 121 } |
125 // There was no configured model path. Switch the state to IDLE and | 122 // There was no configured model path. Switch the state to IDLE and |
126 // fall through to consider the URL. | 123 // fall through to consider the URL. |
127 state_ = LoaderState::IDLE; | 124 state_ = LoaderState::IDLE; |
128 case LoaderState::IDLE: | 125 case LoaderState::IDLE: |
129 if (model_url_.is_valid()) { | 126 if (model_url_.is_valid()) { |
130 StartLoadFromURL(); | 127 StartLoadFromURL(); |
131 break; | 128 break; |
132 } | 129 } |
133 // There was no configured model URL. Switch the state to FINISHED and | 130 // There was no configured model URL. Switch the state to FINISHED and |
134 // fall through. | 131 // fall through. |
135 state_ = LoaderState::FINISHED; | 132 state_ = LoaderState::FINISHED; |
136 case LoaderState::FINISHED: | 133 case LoaderState::FINISHED: |
137 case LoaderState::LOADING_FROM_FILE: | 134 case LoaderState::LOADING_FROM_FILE: |
138 case LoaderState::LOADING_FROM_URL: | 135 case LoaderState::LOADING_FROM_URL: |
139 // Nothing to do. | 136 // Nothing to do. |
140 break; | 137 break; |
141 } | 138 } |
142 } | 139 } |
143 | 140 |
144 void RankerModelLoader::StartLoadFromFile() { | 141 void RankerModelLoader::StartLoadFromFile() { |
145 DCHECK(sequence_checker_.CalledOnValidSequence()); | 142 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
146 DCHECK_EQ(state_, LoaderState::NOT_STARTED); | 143 DCHECK_EQ(state_, LoaderState::NOT_STARTED); |
147 DCHECK(!model_path_.empty()); | 144 DCHECK(!model_path_.empty()); |
148 state_ = LoaderState::LOADING_FROM_FILE; | 145 state_ = LoaderState::LOADING_FROM_FILE; |
149 load_start_time_ = base::TimeTicks::Now(); | 146 load_start_time_ = base::TimeTicks::Now(); |
150 base::PostTaskAndReplyWithResult(background_task_runner_.get(), FROM_HERE, | 147 base::PostTaskAndReplyWithResult(background_task_runner_.get(), FROM_HERE, |
151 base::Bind(&LoadFromFile, model_path_), | 148 base::Bind(&LoadFromFile, model_path_), |
152 base::Bind(&RankerModelLoader::OnFileLoaded, | 149 base::Bind(&RankerModelLoader::OnFileLoaded, |
153 weak_ptr_factory_.GetWeakPtr())); | 150 weak_ptr_factory_.GetWeakPtr())); |
154 } | 151 } |
155 | 152 |
156 void RankerModelLoader::OnFileLoaded(const std::string& data) { | 153 void RankerModelLoader::OnFileLoaded(const std::string& data) { |
157 DCHECK(sequence_checker_.CalledOnValidSequence()); | 154 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
158 DCHECK_EQ(state_, LoaderState::LOADING_FROM_FILE); | 155 DCHECK_EQ(state_, LoaderState::LOADING_FROM_FILE); |
159 | 156 |
160 // Record the duration of the download. | 157 // Record the duration of the download. |
161 RecordTimerHistogram(uma_prefix_ + kReadTimerHistogram, | 158 RecordTimerHistogram(uma_prefix_ + kReadTimerHistogram, |
162 base::TimeTicks::Now() - load_start_time_); | 159 base::TimeTicks::Now() - load_start_time_); |
163 | 160 |
164 // Empty data means |model_path| wasn't successfully read. Otherwise, | 161 // Empty data means |model_path| wasn't successfully read. Otherwise, |
165 // parse and validate the model. | 162 // parse and validate the model. |
166 std::unique_ptr<RankerModel> model; | 163 std::unique_ptr<RankerModel> model; |
167 if (data.empty()) { | 164 if (data.empty()) { |
(...skipping 26 matching lines...) Expand all Loading... |
194 on_model_available_cb_.Run(std::move(model)); | 191 on_model_available_cb_.Run(std::move(model)); |
195 } | 192 } |
196 | 193 |
197 // Notify the state machine. This will immediately kick off a download if | 194 // Notify the state machine. This will immediately kick off a download if |
198 // one is required, instead of waiting for the next organic detection of | 195 // one is required, instead of waiting for the next organic detection of |
199 // ranker activity. | 196 // ranker activity. |
200 NotifyOfRankerActivity(); | 197 NotifyOfRankerActivity(); |
201 } | 198 } |
202 | 199 |
203 void RankerModelLoader::StartLoadFromURL() { | 200 void RankerModelLoader::StartLoadFromURL() { |
204 DCHECK(sequence_checker_.CalledOnValidSequence()); | 201 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
205 DCHECK_EQ(state_, LoaderState::IDLE); | 202 DCHECK_EQ(state_, LoaderState::IDLE); |
206 DCHECK(model_url_.is_valid()); | 203 DCHECK(model_url_.is_valid()); |
207 | 204 |
208 // Do nothing if the download attempts should be throttled. | 205 // Do nothing if the download attempts should be throttled. |
209 if (base::TimeTicks::Now() < next_earliest_download_time_) { | 206 if (base::TimeTicks::Now() < next_earliest_download_time_) { |
210 DVLOG(2) << "Last download attempt was too recent."; | 207 DVLOG(2) << "Last download attempt was too recent."; |
211 ReportModelStatus(RankerModelStatus::DOWNLOAD_THROTTLED); | 208 ReportModelStatus(RankerModelStatus::DOWNLOAD_THROTTLED); |
212 return; | 209 return; |
213 } | 210 } |
214 | 211 |
215 // Kick off the next download attempt and reset the time of the next earliest | 212 // Kick off the next download attempt and reset the time of the next earliest |
216 // allowable download attempt. | 213 // allowable download attempt. |
217 DVLOG(2) << "Downloading model from: " << model_url_; | 214 DVLOG(2) << "Downloading model from: " << model_url_; |
218 state_ = LoaderState::LOADING_FROM_URL; | 215 state_ = LoaderState::LOADING_FROM_URL; |
219 load_start_time_ = base::TimeTicks::Now(); | 216 load_start_time_ = base::TimeTicks::Now(); |
220 next_earliest_download_time_ = | 217 next_earliest_download_time_ = |
221 load_start_time_ + base::TimeDelta::FromMinutes(kMinRetryDelayMins); | 218 load_start_time_ + base::TimeDelta::FromMinutes(kMinRetryDelayMins); |
222 bool request_started = url_fetcher_->Request( | 219 bool request_started = |
223 model_url_, base::Bind(&RankerModelLoader::OnURLFetched, | 220 url_fetcher_->Request(model_url_, |
224 weak_ptr_factory_.GetWeakPtr())); | 221 base::Bind(&RankerModelLoader::OnURLFetched, |
| 222 weak_ptr_factory_.GetWeakPtr()), |
| 223 request_context_getter_.get()); |
225 | 224 |
226 // |url_fetcher_| maintains a request retry counter. If all allowed attempts | 225 // |url_fetcher_| maintains a request retry counter. If all allowed attempts |
227 // have already been exhausted, then the loader is finished and has abandoned | 226 // have already been exhausted, then the loader is finished and has abandoned |
228 // loading the model. | 227 // loading the model. |
229 if (!request_started) { | 228 if (!request_started) { |
230 DVLOG(2) << "Model download abandoned."; | 229 DVLOG(2) << "Model download abandoned."; |
231 ReportModelStatus(RankerModelStatus::MODEL_LOADING_ABANDONED); | 230 ReportModelStatus(RankerModelStatus::MODEL_LOADING_ABANDONED); |
232 state_ = LoaderState::FINISHED; | 231 state_ = LoaderState::FINISHED; |
233 } | 232 } |
234 } | 233 } |
235 | 234 |
236 void RankerModelLoader::OnURLFetched(int /* id */, | 235 void RankerModelLoader::OnURLFetched(bool success, const std::string& data) { |
237 bool success, | 236 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
238 const std::string& data) { | |
239 DCHECK(sequence_checker_.CalledOnValidSequence()); | |
240 DCHECK_EQ(state_, LoaderState::LOADING_FROM_URL); | 237 DCHECK_EQ(state_, LoaderState::LOADING_FROM_URL); |
241 | 238 |
242 // Record the duration of the download. | 239 // Record the duration of the download. |
243 RecordTimerHistogram(uma_prefix_ + kDownloadTimerHistogram, | 240 RecordTimerHistogram(uma_prefix_ + kDownloadTimerHistogram, |
244 base::TimeTicks::Now() - load_start_time_); | 241 base::TimeTicks::Now() - load_start_time_); |
245 | 242 |
246 // On request failure, transition back to IDLE. The loader will retry, or | 243 // On request failure, transition back to IDLE. The loader will retry, or |
247 // enforce the max download attempts, later. | 244 // enforce the max download attempts, later. |
248 if (!success || data.empty()) { | 245 if (!success || data.empty()) { |
249 DVLOG(2) << "Download from '" << model_url_ << "'' failed."; | 246 DVLOG(2) << "Download from '" << model_url_ << "'' failed."; |
(...skipping 23 matching lines...) Expand all Loading... |
273 background_task_runner_->PostTask( | 270 background_task_runner_->PostTask( |
274 FROM_HERE, base::BindOnce(&SaveToFile, model_url_, model_path_, | 271 FROM_HERE, base::BindOnce(&SaveToFile, model_url_, model_path_, |
275 model->SerializeAsString(), uma_prefix_)); | 272 model->SerializeAsString(), uma_prefix_)); |
276 } | 273 } |
277 | 274 |
278 // The loader is finished. Transfer the model to the client. | 275 // The loader is finished. Transfer the model to the client. |
279 state_ = LoaderState::FINISHED; | 276 state_ = LoaderState::FINISHED; |
280 on_model_available_cb_.Run(std::move(model)); | 277 on_model_available_cb_.Run(std::move(model)); |
281 } | 278 } |
282 | 279 |
283 std::unique_ptr<chrome_intelligence::RankerModel> | 280 std::unique_ptr<RankerModel> RankerModelLoader::CreateAndValidateModel( |
284 RankerModelLoader::CreateAndValidateModel(const std::string& data) { | 281 const std::string& data) { |
285 DCHECK(sequence_checker_.CalledOnValidSequence()); | 282 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
286 MyScopedHistogramTimer timer(uma_prefix_ + kParsetimerHistogram); | 283 MyScopedHistogramTimer timer(uma_prefix_ + kParsetimerHistogram); |
287 auto model = RankerModel::FromString(data); | 284 auto model = RankerModel::FromString(data); |
288 if (ReportModelStatus(model ? validate_model_cb_.Run(*model) | 285 if (ReportModelStatus(model ? validate_model_cb_.Run(*model) |
289 : RankerModelStatus::PARSE_FAILED) != | 286 : RankerModelStatus::PARSE_FAILED) != |
290 RankerModelStatus::OK) { | 287 RankerModelStatus::OK) { |
291 return nullptr; | 288 return nullptr; |
292 } | 289 } |
293 return model; | 290 return model; |
294 } | 291 } |
295 | 292 |
296 RankerModelStatus RankerModelLoader::ReportModelStatus( | 293 RankerModelStatus RankerModelLoader::ReportModelStatus( |
297 RankerModelStatus model_status) { | 294 RankerModelStatus model_status) { |
298 DCHECK(sequence_checker_.CalledOnValidSequence()); | 295 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
299 base::HistogramBase* histogram = base::LinearHistogram::FactoryGet( | 296 base::HistogramBase* histogram = base::LinearHistogram::FactoryGet( |
300 uma_prefix_ + kModelStatusHistogram, 1, | 297 uma_prefix_ + kModelStatusHistogram, 1, |
301 static_cast<int>(RankerModelStatus::MAX), | 298 static_cast<int>(RankerModelStatus::MAX), |
302 static_cast<int>(RankerModelStatus::MAX) + 1, | 299 static_cast<int>(RankerModelStatus::MAX) + 1, |
303 base::HistogramBase::kUmaTargetedHistogramFlag); | 300 base::HistogramBase::kUmaTargetedHistogramFlag); |
304 if (histogram) | 301 if (histogram) |
305 histogram->Add(static_cast<int>(model_status)); | 302 histogram->Add(static_cast<int>(model_status)); |
306 return model_status; | 303 return model_status; |
307 } | 304 } |
308 | 305 |
309 } // namespace translate | 306 } // namespace machine_intelligence |
OLD | NEW |