If you look through the files of the "inference" library, you may come across a method called `download` in the class `InferenceModel`. It is intended for downloading model weights in PyTorch (`.pt`) format rather than ONNX. However, the API key does not seem to work for it: the request is rejected with an error.
def download(self, format="pt", location="."):
    """
    Download the weights associated with a model.

    Only PyTorch weights are supported; ONNX is not handled by this method.

    Args:
        format (str): The format of the output.
            - 'pt': returns a PyTorch weights file
        location (str): The directory to save the weights file to. The file
            is written as ``weights.pt`` inside this directory.

    Raises:
        ValueError: If ``format`` is not one of the supported formats.
        requests.HTTPError: If the weights-URL lookup or the weights download
            fails. For the lookup request, the server's ``error`` message
            (if the JSON body has one) is appended to the exception text.
    """
    import os

    supported_formats = ["pt"]
    if format not in supported_formats:
        # ValueError subclasses Exception, so existing `except Exception` callers keep working.
        raise ValueError(f"Unsupported format {format}. Must be one of {supported_formats}")

    # Unpacking into three names also rejects ids that are not "a/b/c"-shaped.
    workspace, project, _ = self.id.rsplit("/")

    # Ask the API for the URL of the .pt weights.
    pt_api_url = f"{API_URL}/{workspace}/{project}/{self.version}/ptFile"
    r = requests.get(pt_api_url, params={"api_key": self.__api_key})
    try:
        r.raise_for_status()
    except requests.HTTPError as err:
        # raise_for_status() only reports the status line; include the server's
        # own explanation (e.g. an authorization message) when it sent one.
        try:
            payload = r.json()
        except ValueError:
            payload = None
        detail = payload.get("error") if isinstance(payload, dict) else None
        message = f"{err} ({detail})" if detail else str(err)
        raise requests.HTTPError(message, response=r) from err

    pt_weights_url = r.json()["weightsUrl"]

    destination = os.path.join(location, "weights.pt")
    chunk_size = 1024
    # Stream to disk in chunks; the `with` closes the response even if writing fails.
    with requests.get(pt_weights_url, stream=True) as response:
        # Check before opening the file so an error page is never saved as weights.
        response.raise_for_status()
        content_length = response.headers.get("content-length")
        # Content-Length may be absent; then the progress bar has no known total.
        total_chunks = -(-int(content_length) // chunk_size) if content_length else None
        with open(destination, "wb") as f:
            for chunk in tqdm(
                response.iter_content(chunk_size=chunk_size),
                desc=f"Downloading weights to {destination}",
                total=total_chunks,
            ):
                if chunk:  # skip empty chunks
                    f.write(chunk)
Here is the server's response:
{
"error": "Not authorized to download this model in pt format."
}
Does anybody have an idea how to get this working?