这是最简单的修改方法。你只需要在代码的最开始设置环境变量,让
huggingface_hub 走国内的镜像站点(如
hf-mirror.com)。
操作步骤:
修改你的
train.py 文件,在所有
import 语句的最上方(甚至在
import timm 之前)添加以下两行代码:
import os
# 设置 Hugging Face 镜像地址
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
# 原有的 import...
import torch
import torch.nn as nn
# ...
或者,你也可以直接在运行命令的终端中设置环境变量:
export HF_ENDPOINT=https://hf-mirror.com
这样
timm 就会通过镜像站自动下载
swin_tiny_patch4_window7_224 的权重,速度通常很快。
如果镜像站也无法访问,或者你需要离线训练,可以手动下载权重文件,然后修改代码从本地加载。
操作步骤:
下载权重文件: 你需要下载
swin_tiny_patch4_window7_224 的
.pth 或
.bin 文件。
你可以在能上网的电脑上访问 Hugging Face timm/swin_tiny_patch4_window7_224.ms_in1k 下载
model.safetensors 或者
pytorch_model.bin。
或者搜索
swin_tiny_patch4_window7_224.ms_in1k.pth 下载。
上传到服务器: 将下载好的文件放到你的代码目录下,例如
/gemini/code/pretrained/swin_tiny_patch4_window7_224.ms_in1k.bin。
修改
model_acl.py 代码: 你需要修改
build_model_multiskin 函数,将
pretrained=True 改为
False,然后手动加载权重。
# 修改 model_acl.py 中的 build_model_multiskin 函数
def build_model_multiskin(args):
# 1. 关闭自动下载
encoder = timm.create_model('swin_tiny_patch4_window7_224', pretrained=False)
# 2. 指定本地权重路径 (请确保路径正确)
checkpoint_path = '/gemini/code/pretrained/swin_tiny_patch4_window7_224.ms_in1k.bin'
if os.path.exists(checkpoint_path):
print(f"正在加载本地预训练权重: {checkpoint_path}")
# 加载权重
checkpoint = torch.load(checkpoint_path, map_location='cpu')
# timm 的权重有时包裹在 'model' 或 'state_dict' 键中,有时直接是字典
# 这里做一个简单的兼容处理
if 'model' in checkpoint:
state_dict = checkpoint['model']
elif 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
state_dict = checkpoint
# 加载参数,strict=False 可以容忍微小的键名不匹配
encoder.load_state_dict(state_dict, strict=False)
else:
print(f"警告: 未找到本地权重文件 {checkpoint_path},将使用随机初始化!")
model = model_multiskin(encoder, num_classes=args.class_num, shared=args.shared, args=args)
return model
如果你在服务器上配置了科学上网代理,可以在代码中设置代理端口。
在
train.py 开头添加:
import os
# 将下面的地址替换为你的代理地址
os.environ['http_proxy'] = 'http://127.0.0.1:7890'
os.environ['https_proxy'] = 'http://127.0.0.1:7890'
首先尝试方案一(HF镜像)。这是改动最小、最通用的方法。你只需要在
train.py 的第一行加上:
import os; os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
下一步: 如果不确定应该加在哪里,你可以直接在命令行输入以下命令先测试一下是否能连通:
HF_ENDPOINT=https://hf-mirror.com python -c "import timm; m = timm.create_model('swin_tiny_patch4_window7_224', pretrained=True); print('下载成功')"