diff --git a/pwnagotchi/ai/train.py b/pwnagotchi/ai/train.py index 4ca26389..c9846ca6 100644 --- a/pwnagotchi/ai/train.py +++ b/pwnagotchi/ai/train.py @@ -77,9 +77,11 @@ class Stats(object): }) temp = "%s.tmp" % self.path + back = "%s.bak" % self.path with open(temp, 'wt') as fp: fp.write(data) + os.replace(self.path, back) os.replace(temp, self.path) @@ -113,7 +115,9 @@ class AsyncTrainer(object): def _save_ai(self): logging.info("[ai] saving model to %s ..." % self._nn_path) temp = "%s.tmp" % self._nn_path + back = "%s.bak" % self._nn_path self._model.save(temp) + os.replace(self._nn_path, back) os.replace(temp, self._nn_path) def on_ai_step(self):