ed186ad to c8fe804
Why not just get the weights via a Python dict, since we're loading the weights into memory anyway? Then it would also work for a PyTorch state dict, and you could share the weights with a PyTorch process that also loads the same model. The advantage of this API would come if we DMA or stream to the device in the future to avoid using RAM (which is probably what we want).
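To make the "avoid using RAM" point concrete, here is a rough sketch that memory-maps a safetensors file and hands out a zero-copy view of one tensor's bytes. It uses `mmap` as a stand-in for real DMA or streaming, which it is not, and `map_tensor` is a hypothetical helper, not something in this PR or in the safetensors library. It only relies on the published file layout: an 8-byte little-endian header length, a JSON header, then the data buffer.

```python
import json
import mmap
import struct


def map_tensor(path, name):
    """Return (mapping, view) for one tensor's raw bytes.

    Hypothetical helper for illustration only. The view is a slice of the
    memory-mapped file, so the tensor bytes are not copied up front.
    """
    with open(path, "rb") as f:
        # The mapping stays valid after the file object is closed.
        mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)

    # First 8 bytes: little-endian u64 holding the JSON header length.
    (header_len,) = struct.unpack_from("<Q", mm, 0)
    header = json.loads(mm[8:8 + header_len])

    # data_offsets are relative to the start of the data buffer,
    # which begins right after the header.
    begin, end = header[name]["data_offsets"]
    base = 8 + header_len
    return mm, memoryview(mm)[base + begin:base + end]
```

The caller has to release the view before closing the mapping. A Rust-side loader could follow the same idea and pass the mapped range on to the device without an intermediate copy.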
Cargo.toml
Outdated
rustnn = { git = "https://github.com/rustnn/rustnn", branch = "main" }
serde_json = "1.0"
webnn-graph = { git = "https://github.com/rustnn/webnn-graph", branch = "main" }
safetensors = "0.4"
AI keeps adding old versions. Most recent version is safetensors = "0.7.0"
@gedoensmax added this feature to share weights between PyTorch and WebNN. It is an additional feature, not the only way to load weights quickly.
Fully AI-written. I just told it to let me provide a safetensors file and read the bytes out of that. I have not looked at how this works at all and focused on the Python pieces.
But what do you mean by sharing? Do you want
- to have a safetensors file that is then read by either PyTorch or RustNN,
- or to load safetensors and then use the weights in both PyTorch and RustNN in the same execution? For the latter, I would just use Python safetensors and accept a weight dict here, similar to PyTorch.
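For reference, a minimal sketch of what such a weight dict could look like when built from a safetensors file using only the Python standard library. `read_safetensors` is a hypothetical name used for illustration; in practice the Python safetensors package would do this and return framework tensors rather than raw bytes. The sketch assumes the published file layout: an 8-byte little-endian header length, a JSON header mapping tensor names to `dtype`, `shape` and `data_offsets`, then the data buffer.

```python
import json
import struct


def read_safetensors(path):
    """Return {tensor_name: (dtype, shape, raw_bytes)}.

    Hypothetical helper for illustration only; it reads the whole data
    buffer into memory, which is exactly the copy a Rust-side loader
    might want to avoid.
    """
    with open(path, "rb") as f:
        # First 8 bytes: little-endian u64 holding the JSON header length.
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len).decode("utf-8"))
        data = f.read()  # everything after the header is tensor data

    weights = {}
    for name, info in header.items():
        if name == "__metadata__":  # optional string-to-string metadata
            continue
        # Offsets are relative to the start of the data buffer.
        begin, end = info["data_offsets"]
        weights[name] = (info["dtype"], info["shape"], data[begin:end])
    return weights
```

A dict of this shape could be passed across the Python boundary as-is, whereas the file-path variant leaves the loading strategy to Rust.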
The first option. I want to be able to load weights into WebNN that are already present in safetensors for PyTorch. The idea was to let Rust control how the weights are loaded from disk, instead of loading them in Python and passing them as a dict.
@mtavenrath, please make the change to 0.7, since I cannot edit this branch on the repo.
theHamsta
left a comment
Approved when the newest safetensors = "0.7" is used
Being able to load safetensors allows sharing weights between WebNN and PyTorch, avoiding file duplication.