-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathSaveModelToFile.lua
More file actions
90 lines (79 loc) · 2.86 KB
/
SaveModelToFile.lua
File metadata and controls
90 lines (79 loc) · 2.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
------------------------------------------------------------------------
--[[ SaveModelToFile ]]--
-- Strategy, not an Observer: its owner decides when to call save().
-- Used to save the version of the subject with the lowest error.
------------------------------------------------------------------------
-- Declare the class; torch.class creates it under the qualified name
-- 'dp.SaveModelToFile' and returns the table that holds its methods.
local SaveModelToFile = torch.class('dp.SaveModelToFile')
-- Marker field, presumably so other code can test "is this a
-- SaveModelToFile?" without a class lookup.
SaveModelToFile.isSaveModelToFile = true
-- Constructor. Accepts a single table of key-value options:
--   in_memory (boolean, default false): keep the serialized model in
--     memory and only write it to file when the experiment is done.
--   save_dir  (string): output directory; falls back to dp.SAVE_DIR.
--   verbose   (boolean, default true): allow messages on stdout.
-- Positional arguments are rejected.
function SaveModelToFile:__init(config)
   local opts = config or {}
   assert(not opts[1], "Constructor requires key-value arguments")
   -- xlua.unpack returns the (unused) argument table first, then each
   -- declared option in declaration order.
   local _, in_memory, save_dir, verbose = xlua.unpack(
      {opts},
      'SaveModelToFile',
      'Saves version of the subject with the lowest error',
      {arg='in_memory', type='boolean', default=false,
       help='only saves the subject to file at the end of the experiment'},
      {arg='save_dir', type='string', help='defaults to dp.SAVE_DIR'},
      {arg='verbose', type='boolean', default=true,
       help='can print messages to stdout'}
   )
   self._verbose = verbose
   self._in_memory = in_memory
   self._save_dir = save_dir or dp.SAVE_DIR
end
-- Binds this strategy to a subject and mediator and prepares the output
-- path <save_dir>/<subject:id():toPath()>.dat, creating its parent
-- directory with `mkdir -p`.
-- In in_memory mode it subscribes to the mediator's 'doneExperiment'
-- event so the cached model is written to disk at the end.
-- @param subject  must provide id():toPath()
-- @param mediator stored on self; only used here when in_memory is set
function SaveModelToFile:setup(subject, mediator)
   self._mediator = mediator
   if self._in_memory then
      self._mediator:subscribe('doneExperiment', self, 'doneExperiment')
   end
   --concatenate save directory with subject id
   self._filename = paths.concat(self._save_dir, subject:id():toPath() .. '.dat')
   -- Single-quote the directory for the POSIX shell, turning each embedded
   -- ' into '\'' so that spaces or shell metacharacters in save_dir are
   -- passed to mkdir verbatim instead of being split or interpreted.
   local dir = sys.dirname(self._filename)
   local quoted = "'" .. (dir:gsub("'", "'\\''")) .. "'"
   os.execute('mkdir -p ' .. quoted)
end
-- Returns the path the model is written to. It is assigned in setup(),
-- so this returns nil before setup() has been called.
function SaveModelToFile:filename()
   local path = self._filename
   return path
end
-- Saves subject._model.
-- The model is cloned, forget() is called on the clone, and the clone is
-- converted with float() before serialization; the live model itself is
-- not modified. Both modes now store this same snapshot, so the file
-- produced by doneExperiment() in in_memory mode matches what file mode
-- writes.
--   in_memory: the snapshot is serialized into self._save_cache (a
--     string), replacing any previous cache; doneExperiment() writes it.
--   otherwise: the snapshot is written to self._filename immediately.
-- @param subject object with a `_model` member (both asserted)
-- @return the result of torch.save in file mode; nothing in in_memory mode
function SaveModelToFile:save(subject)
   assert(subject, "SaveModelToFile not setup error")
   assert(subject._model, "subject does not have member _model")
   if self._in_memory then
      dp.vprint(self._verbose, 'SaveModelToFile: serializing subject to memory')
      -- release the old cache first so it can be collected while the
      -- new snapshot is being built
      self._save_cache = nil
      self._save_cache = torch.serialize(subject._model:clone():forget():float())
   else
      dp.vprint(self._verbose, 'SaveModelToFile: saving to '.. self._filename)
      return torch.save(self._filename, subject._model:clone():forget():float())
   end
end
-- Event handler: in in_memory mode, writes the cached serialized model
-- (if save() has produced one) to self._filename. Does nothing otherwise.
-- Raises an error naming the file if it cannot be opened for writing.
function SaveModelToFile:doneExperiment()
   if self._in_memory and self._save_cache then
      dp.vprint(self._verbose, 'SaveModelToFile: saving to '.. self._filename)
      -- 'wb': the cache is binary torch serialization, which text mode
      -- could corrupt via newline translation on some platforms.
      -- assert surfaces io.open's error message instead of a later
      -- "attempt to index a nil value".
      local f = assert(io.open(self._filename, 'wb'))
      f:write(self._save_cache)
      f:close()
   end
end
-- the following are called by torch.File during [un]serialization
-- Custom torch.File serialization hook: writes a copy of this object's
-- fields, leaving out _save_cache. That cache already holds the
-- serialized model, so including it would store the model twice.
-- A plain pairs loop is used instead of Moses `_.map`, whose callback
-- argument order differs between Moses versions.
-- @param file torch.File to write to
function SaveModelToFile:write(file)
   local state = {}
   for k, v in pairs(self) do
      if k ~= '_save_cache' then
         state[k] = v
      end
   end
   file:writeObject(state)
end
-- Custom torch.File deserialization hook: reads back the table of
-- fields stored by write() and copies each entry onto this object.
-- @param file torch.File to read from
function SaveModelToFile:read(file)
   local fields = file:readObject()
   for name, value in pairs(fields) do
      self[name] = value
   end
end
-- Sets whether messages may be printed. Calling with no argument (nil)
-- enables verbosity; any other value is stored as given.
-- @param verbose boolean or nil
function SaveModelToFile:verbose(verbose)
   local flag = verbose
   if flag == nil then
      flag = true
   end
   self._verbose = flag
end
-- Disables message printing; shorthand for self:verbose(false).
function SaveModelToFile:silent()
   self:verbose(false)
end