第193期 如何微调大语言模型(LLM)(内含源码细节)

  • 时间:2025-11-22 23:23 作者: 来源: 阅读:0
  • 扫一扫,手机访问
摘要:*** AI拉呱,专注于人工智领域与AI工具、前沿技术解读。*** 我一直认为,大语言模型(LLM)是法律行业的理想工具。法律工作本质上围绕文档展开——阅读、撰写以及解读合同、法规、法院判决等复杂法律文本,而大语言模型的设计初衷正是理解和生成文本。 大语言模型能够快速总结文档、查找相关信息并解答法律问题,这使其成为处理律师日常繁重读写工作的理想选择。它们还能通过自动化处理合同审查、法律研究、诉

*** AI拉呱,专注于人工智领域与AI工具、前沿技术解读。***

我一直认为,大语言模型(LLM)是法律行业的理想工具。法律工作本质上围绕文档展开——阅读、撰写以及解读合同、法规、法院判决等复杂法律文本,而大语言模型的设计初衷正是理解和生成文本。

大语言模型能够快速总结文档、查找相关信息并解答法律问题,这使其成为处理律师日常繁重读写工作的理想选择。它们还能通过自动化处理合同审查、法律研究、诉讼案件中文档筛选等重复性任务节省时间。法律专业人士无需在这些工作上花费数小时,只需借助大语言模型提高效率,将精力集中在更重要的事务上。

大语言模型还可基于特定法律数据(如某国法律或法院案例)进行训练,从而在对应法律体系中变得更加精准实用。而这正是我尝试要做的事。

在本文中,我将向你展示如何利用新加坡全部立法内容微调Llama-3.1–8B-Instruct大语言模型,并使用该模型进行推理生成回复。微调过程主要分为以下3个步骤:

构建微调数据集 从“新加坡法规在线”(Singapore Statutes Online)下载所有成文法及附属立法文件,这些文件均以PDF格式提供提取每个文件的内容,并将其保存到对应的JSON文件中,每个JSON文件包含一项立法内容将JSON文件转换为Alpaca格式的JSONL文件(该格式与Hugging Face数据集工具包兼容),此文件即为我们的微调数据集 微调大语言模型 使用上述微调数据集,借助Unsloth微调工具,通过参数高效微调(PEFT)结合LoRA适配器,对Meta的Llama-3.1–8B-Instruct模型进行微调将LoRA适配器与基础模型合并,生成微调后的模型 使用微调后的模型进行推理 最后,通过推理脚本调用微调后的模型进行推理。

现在,让我们开始吧!

构建微调数据集

构建微调数据集是整个流程的第一步。这绝非一项简单的任务——它需要精心进行数据提取、清洗、格式化,并设计合理的示例生成方式,以确保模型能有效学习。即便模型架构和训练配置无误,结构混乱或不一致的数据也会导致训练结果不理想。

数据集的质量和多样性直接影响模型的泛化能力和准确遵循指令的能力。在法律这类对精准度和上下文要求极高的领域,生成涵盖多种问题类型、答案清晰可靠的示例尤为重要。这一步的工作质量,决定了微调后的模型是否具有实际使用价值。

构建此类数据集需要在自动化处理与人工监督之间找到平衡——自动化脚本可高效生成大量数据,但验证数据的准确性和相关性往往需要人工审核和领域知识支持。在这一阶段投入时间和精力,后续将能收获回报:得到一个不仅精准,而且可靠、符合目标用户需求的模型。

尽管如此,构建这类数据集仍是微调大语言模型过程中最困难、最繁琐的环节,通常会占据总工作量和总耗时的70%至90%。不过,我不会在此处展示构建数据集所用的代码——这些代码篇幅较长,且并非本文重点。总体而言,我会选择最适合自己的语言和库来完成任务:用Python下载PDF文件,用Go处理这些文件。

由于文件数量庞大,我还大量使用大语言模型来处理文件。具体来说,我在笔记本电脑上运行Ollama,并使用gemma3-4B模型进行处理。

下载立法文件

你可在“新加坡法规在线”(Singapore Statutes Online)网站上获取新加坡全部立法内容,该网站由新加坡总检察署立法处(Legislation Division of the Attorney-General’s Chambers of Singapore)运营,向公众免费开放。

尽管网站内容免费,但所有文件均以PDF格式分散存储在不同页面。此外,每项成文法可能还包含附属立法,导致文件分布十分零散。不过,好在这些文件都以格式规范的PDF形式整理归档。

手动下载每个文件无疑会让项目停滞不前。幸运的是,我们可以通过自动化方式完成下载:我只需分析网页HTML结构确定需下载的内容,然后编写脚本即可批量下载PDF文件。

与许多涉及网页爬取或HTML解析的脚本类似,我使用BeautifulSoup库从繁杂的标签中筛选出正确的URL,其余操作则借助标准Python库完成。若一切顺利,最终可下载到3581个PDF文件,每个文件包含一项成文法或附属立法内容。

提取内容

接下来,我们需要提取PDF文件中的内容并将其转换为JSON格式。这实际上是一个中间步骤——我们最终需要的是JSONL格式的微调数据集,但通过JSON格式的中间文件,我们能更直观地检查数据是否存在问题。

然而,将PDF内容转换为JSON格式本身并非易事。原因在于,尽管许多文件格式相似,但并非所有文件都采用统一格式;此外,部分文件篇幅较短,但也有文件内容极其庞大,例如《1947年所得税法》(Income Tax Act 1947)长达1260页,《2001年证券与期货法》(Securities and Futures Act 2001)也有1086页。

大多数法律法案会分为多个部分(part),每个部分又细分为多个条款(section),部分条款还可能进一步分为子条款(subsection)。为简化处理,我仅保留到“条款”层级。也就是说,每个JSON文件的结构大致如下(实际内容会丰富得多):


{
  "title": "INCOME TAX ACT 1947",
  "parts": [
    {
      "title": "PRELIMINARY",
      "sections": [
        {
          "title": "1. Short title",
          "text": "This Act is the Income Tax Act 1947."
        }
      ]
    }
  ]
}

看似简单的操作,实际执行起来却并不顺利。我最终采用的流程是:先将PDF文件转换为文本格式,按“部分”拆分文本,再借助大语言模型将每个“部分”处理为“条款”结构,最后将这些“条款”重新整合。部分文件因格式不统一,甚至需要手动拆分处理。

转换为JSONL格式

Hugging Face数据集格式是一种标准化的数据表示方式,可通过datasets库使用,尤其适用于语言模型的训练和微调。该格式将数据视为一系列示例的集合,每个示例以字典形式呈现,包含“instruction”(指令)、“input”(输入)、“output”(输出)等命名字段。这种格式高效且易于操作,能与transformers、tokenizers、Trainer等其他Hugging Face工具良好兼容。

在指令遵循类任务中,数据集通常包含三个字段:

指令(instruction):告知模型需执行的任务输入(input):任务所需的内容或上下文输出(output):预期的响应结果

这种结构适用于摘要生成、翻译、问答等多种任务。我选择采用Alpaca格式(该格式遵循上述结构),原因在于它简洁高效,被广泛用于指令微调,且得到开源训练脚本的良好支持,同时与现代大语言模型采用的“提示-响应”模式高度契合。

Alpaca格式的流行源于斯坦福大学发布的Alpaca模型——该模型证明,借助少量有监督数据和开源工具,就能实现出色的指令遵循能力。Alpaca格式高度模拟对话式的“提示-响应”结构,非常适合训练实用的通用助手。它之所以被广泛使用,正是因为其简洁直观、兼容Hugging Face工具,且能适配多种任务场景。

Alpaca格式通常以JSONL(JSON Lines)文件格式存储和共享,文件中每一行都是一个独立的JSON对象,代表数据集中的一个示例。这种格式支持逐行处理大型数据集,无需将整个文件加载到内存中,处理效率更高。

以下是JSONL文件的示例:


{"instruction": "Translate to Spanish", "input": "Hello", "output": "Hola"}
{"instruction": "Summarize the paragraph", "input": "Large language models are...", "output": "LLMs are powerful AI tools."}

将JSON文件转换为Hugging Face数据集格式,仍需借助大语言模型。具体策略是:针对之前生成的每个JSON文件,生成大量“指令-输入-输出”组合。这一步看似简单,实则耗时较长。我使用gemma3-4B大语言模型生成指令,效果出人意料地好。

最终,针对3581个PDF文件,我生成了约20万组“指令-输入-输出”示例。部分文件仅对应1-2个问题,而有些文件则对应多达2000条指令。可想而知,整个处理过程花费了相当长的时间。

所有示例生成完成后,我将其整合成结构化数据集,并拆分为训练集(占比90%)和验证集(占比10%)。生成的数据集以多种格式保存:适用于机器学习流程的Hugging Face兼容格式、便于人工检查的JSON格式,以及展示样本数量和分布情况的汇总统计数据。

至此,微调数据集准备就绪!

微调Llama-3.1–8B-Instruct模型

这一步是整个流程的核心环节。看似复杂,实则步骤清晰。但这就像下棋——学会容易,精通难。微调的步骤虽然相对简单,但要得到一个优质、精准的微调模型,难度却很大。

硬件配置

我最近购置了一台用于工作的台式电脑,配置如下:AMD Ryzen 9 9950X处理器、64GB DDR5内存、4TB SSD存储空间。其中最关键的组件是NVIDIA GeForce RTX 5090显卡,该显卡拥有32GB显存,内存带宽达1792GB/s。我正是打算在这台电脑上尝试微调Llama-3.1–8B-Instruct模型。

如果你怀疑这个微调项目只是我想测试新电脑性能的借口,那你完全猜对了。

工具选择

微调技术已不算新颖,实现方式有多种。最兼容的方式或许是直接使用Hugging Face Transformers及其他Hugging Face库,但这种方式速度较慢。由于我仅使用单台台式机的显卡进行微调,因此更倾向于选择更快的方式(事实上,无论何种情况,能快速微调总比慢速微调要好)。

另外两种常用工具是Unsloth和Axolotl。两者均为开源工具,可加速Llama-3.1–8B-Instruct等开源权重模型的微调,且均基于Hugging Face Transformers构建。

Unsloth是Hugging Face Transformers的性能优化封装工具,能将大语言模型微调速度提升两倍,同时大幅降低显卡内存占用,非常适合在消费级显卡上进行指令微调或持续预训练。

而Axolotl则是另一个高度可配置的微调框架,同样基于Hugging Face生态系统构建,适用于复杂的大规模训练流程。

单从功能描述来看,Unsloth似乎是更优选择,但决定性因素在于Axolotl不支持RTX 5090显卡!因此,我最终选择了Unsloth。

环境配置

或许在意料之中,但对我而言仍有些意外:硬件设备与微调软件的关联性极强。我过去长期在软件可跨设备、甚至跨操作系统移植的平台上工作,如今需要应对不同设备的特性差异,既有趣又令人困扰。

此处的主要问题集中在RTX 5090显卡上。

GeForce RTX 5090属于Blackwell架构显卡,计算能力为sm_120(支持CUDA 12.x)。在当前时间(2025年7月),PyTorch的最新稳定版本为2.7.1,虽支持CUDA 12.x,但默认安装仅支持到CUDA 12.6,而RTX 5090显卡需要CUDA 12.8版本。

安装相关库时,建议使用uv工具。执行以下命令安装uv:


curl -LsSf <https://astral.sh/uv/install.sh> | sh && source $HOME/.local/bin/env

随后,创建Python 3.12版本的虚拟环境(若你习惯使用conda,也可选择conda创建):


uv venv .venv --python=3.12 --seed
source .venv/bin/activate

接下来,安装适配的PyTorch版本,注意需指定对应安装包:


uv pip install torch==2.7.1 --index-url <https://download.pytorch.org/whl/cu128>

这可能是目前遇到的最大障碍——解决PyTorch的兼容性问题堪称一场噩梦。更糟糕的是,我最初尝试时,当前稳定版本尚未发布,不得不使用PyTorch 2.8的夜间版本(nightly version)才能正常运行,花了很长时间才解决这一问题。

除PyTorch外,运行Unsloth还需安装一系列其他库:


uv pip install unsloth unsloth_zoo bitsandbytes

其中最棘手的(对我而言)是Triton库。目前,必须使用triton>=3.3.1版本,否则Unsloth无法正常运行。即便你想不使用Triton并禁用它,也无法顺利解决问题——因为Unsloth的部分功能本身依赖Triton,若版本不匹配,将无法与RTX 5090显卡兼容。执行以下命令安装适配的Triton版本:


uv pip install triton>=3.3.1

当然,你还需要安装Hugging Face Transformers,但需使用特定版本:


uv pip install -U transformers==4.52.4

获取基础模型与使用Weights and Biases

我们还需要登录Hugging Face和Weights and Biases两个平台。

登录Hugging Face是因为需要下载用于微调的基础模型。为何必须登录?因为Llama-3.1–8B-Instruct是受限模型,需登录Hugging Face才能加载并用于微调。

只需登录https://huggingface.co/(若没有账号,注册一个即可),然后进入https://huggingface.co/settings/tokens页面创建用户访问令牌。之后返回终端,将令牌设置为环境变量:


$ export HF_TOKEN=<你的用户访问令牌>

至于Weights and Biases,我们主要用它进行日志记录,但它还有许多其他功能。这是一个非常实用的工具,能帮助我们捕获日志(及其他信息),让我们无需登录查看日志就能开展其他工作并监控训练进程。

首先在https://wandb.ai/注册账号并登录,然后前往https://wandb.ai/authorize获取API密钥。之后在命令行中执行以下命令:


$ wandb login

并粘贴获取到的API密钥。此外,也可在运行脚本前设置WANDB_API_KEY环境变量。

微调脚本

以下是一个长达600多行的脚本,却是微调Llama-3.1–8B-Instruct模型的完整脚本。


#!/usr/bin/env python3
"""
基于新加坡立法内容微调Llama 3.1 8B Instruct模型
使用Unsloth工具结合PEFT/LoRA技术
针对RTX 5090(Blackwell架构)显卡优化,适配CUDA 12.8

硬件要求:
- AMD Ryzen 9 9950X处理器、64GB内存、RTX-5090显卡(Blackwell架构)
- 操作系统:Ubuntu 24.04.2 LTS
- CUDA版本:12.8
"""

import os
import sys
import logging
import math
import json
from datetime import datetime
from datasets import load_from_disk, Dataset

import torch
import numpy as np
from unsloth import FastLanguageModel
from trl import SFTTrainer
from transformers import (
    TrainingArguments,
    TextStreamer,
    TrainerCallback,
    EarlyStoppingCallback,
    TrainerControl,
    TrainerState,
)
from peft import LoraConfig
import wandb
import torch.nn.functional as F
from typing import Optional, Union, Dict, Any, List, Tuple

class RTX5090OptimizedTrainer(SFTTrainer):
    """针对RTX 5090显卡优化的自定义SFT训练器,提升稳定性"""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """
        支持Triton加速的稳定损失计算
        针对RTX 5090显卡优化,提升性能
        """
        # 获取模型输出
        outputs = model(
            input_ids=inputs.get("input_ids"),
            attention_mask=inputs.get("attention_mask"),
            labels=inputs.get("labels"),
            return_dict=True,
        )

        # 获取对数概率和标签
        logits = outputs.logits
        labels = inputs.get("labels")

        # 位移操作:使第n个token由前n-1个token预测
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # 展平token维度
        loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
        shift_logits = shift_logits.view(-1, shift_logits.size(-1))
        shift_labels = shift_labels.view(-1)

        # 计算损失,确保数值稳定性
        loss = loss_fct(shift_logits, shift_labels)

        # 检测并处理NaN/Inf值
        if torch.isnan(loss) or torch.isinf(loss):
            logger.warning(
                "检测到NaN/Inf损失值,使用备用计算方式"
            )
            loss = torch.tensor(
                0.0, device=loss.device, requires_grad=True
            )

        return (loss, outputs) if return_outputs else loss

class StabilityCallback(TrainerCallback):
    """用于监控训练稳定性和检测训练偏离的回调类"""

    def __init__(
        self,
        divergence_threshold=1.3,
        min_steps_before_check=100,
        patience=2,
        no_improvement_patience=5,
        gradient_explosion_threshold=5.0,
    ):
        super().__init__()
        self.divergence_threshold = divergence_threshold  # 训练偏离阈值
        self.min_steps_before_check = min_steps_before_check  # 开始检查的最小步数
        self.patience = patience  # 训练偏离容忍次数
        self.no_improvement_patience = no_improvement_patience  # 无提升容忍次数
        self.gradient_explosion_threshold = gradient_explosion_threshold  # 梯度爆炸阈值
        self.best_val_loss = float("inf")  # 最佳验证损失
        self.divergence_count = 0  # 训练偏离计数
        self.no_improvement_count = 0  # 无提升计数
        self.validation_losses = []  # 验证损失记录
        self.gradient_norms = []  # 梯度范数记录
        self.consecutive_high_gradients = 0  # 连续高梯度计数

    def on_log(self, args, state, control, logs=None, **kwargs):
        """监控梯度范数和训练稳定性"""
        if logs:
            # 跟踪梯度范数
            if "grad_norm" in logs:
                self.gradient_norms.append(logs["grad_norm"])

                # 检测梯度爆炸
                if logs["grad_norm"] > self.gradient_explosion_threshold:
                    self.consecutive_high_gradients += 1
                    logger.error(
                        f"检测到梯度爆炸:{logs['grad_norm']:.4f} > {self.gradient_explosion_threshold}"
                    )

                    if self.consecutive_high_gradients >= 3:
                        logger.error(
                            "因持续梯度爆炸,停止训练"
                        )
                        control.should_training_stop = True
                        return
                elif logs["grad_norm"] > 1.0:
                    logger.warning(
                        f"检测到高梯度范数:{logs['grad_norm']:.4f}"
                    )
                    self.consecutive_high_gradients = max(
                        0, self.consecutive_high_gradients - 1
                    )
                else:
                    self.consecutive_high_gradients = 0

                # 记录梯度统计信息
                if len(self.gradient_norms) >= 10:
                    recent_norms = self.gradient_norms[-10:]
                    avg_norm = np.mean(recent_norms)
                    max_norm = np.max(recent_norms)

                    wandb.log(
                        {
                            "gradient_norm_avg_10": avg_norm,  # 近10步梯度范数平均值
                            "gradient_norm_max_10": max_norm,  # 近10步梯度范数最大值
                            "gradient_norm_current": logs["grad_norm"],  # 当前梯度范数
                            "consecutive_high_gradients": self.consecutive_high_gradients,  # 连续高梯度计数
                        },
                        step=state.global_step,
                    )

            # 检测训练损失中的NaN/Inf值
            if "train_loss" in logs:
                if np.isnan(logs["train_loss"]) or np.isinf(logs["train_loss"]):
                    logger.error(
                        f"在第{state.global_step}步检测到训练损失为NaN/Inf值"
                    )
                    control.should_training_stop = True

    def on_evaluate(self, args, state, control, logs=None, **kwargs):
        """检测训练偏离并执行早停策略"""
        if (
            logs
            and "eval_loss" in logs
            and state.global_step >= self.min_steps_before_check
        ):
            current_val_loss = logs["eval_loss"]
            self.validation_losses.append(current_val_loss)

            # 检查是否有性能提升
            if current_val_loss < self.best_val_loss:
                improvement = self.best_val_loss - current_val_loss
                self.best_val_loss = current_val_loss
                self.divergence_count = 0
                self.no_improvement_count = 0
                logger.info(
                    f"新的最佳验证损失:{self.best_val_loss:.4f}(提升幅度:{improvement:.4f})"
                )
            else:
                self.no_improvement_count += 1

                # 检查是否出现训练偏离
                if current_val_loss > self.best_val_loss * self.divergence_threshold:
                    self.divergence_count += 1
                    logger.warning(
                        f"检测到潜在训练偏离:{current_val_loss:.4f} > {self.best_val_loss * self.divergence_threshold:.4f}(计数:{self.divergence_count})"
                    )

                    if self.divergence_count >= self.patience:
                        logger.error(
                            f"训练偏离!在第{state.global_step}步停止训练"
                        )
                        control.should_training_stop = True
                        return

                # 检查是否长期无性能提升
                if self.no_improvement_count >= self.no_improvement_patience:
                    logger.warning(
                        f"{self.no_improvement_count}次评估无性能提升,停止训练"
                    )
                    control.should_training_stop = True
                    return

            # 计算困惑度(perplexity)
            perplexity = math.exp(current_val_loss)

            # 扩展日志记录
            wandb.log(
                {
                    "eval_perplexity": perplexity,  # 验证集困惑度
                    "best_val_loss": self.best_val_loss,  # 最佳验证损失
                    "divergence_count": self.divergence_count,  # 训练偏离计数
                    "no_improvement_count": self.no_improvement_count,  # 无提升计数
                    "val_loss_trend": (
                        current_val_loss - self.validation_losses[-2]
                        if len(self.validation_losses) >= 2
                        else 0
                    ),  # 验证损失变化趋势
                },
                step=state.global_step,
            )

            logger.info(
                f"第{state.global_step}步:验证损失:{current_val_loss:.4f},困惑度:{perplexity:.2f},无提升次数:{self.no_improvement_count}"
            )

def print_system_info():
    """打印系统信息"""
    logger.info("=== 系统信息 ===")
    logger.info(f"Python版本:{sys.version}")
    logger.info(f"PyTorch版本:{torch.__version__}")
    logger.info(f"CUDA是否可用:{torch.cuda.is_available()}")

    if torch.cuda.is_available():
        logger.info(f"CUDA版本:{torch.version.cuda}")
        logger.info(f"显卡数量:{torch.cuda.device_count()}")
        logger.info(f"显卡名称:{torch.cuda.get_device_name(0)}")
        logger.info(
            f"显卡内存:{torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB"
        )

    logger.info("=== 系统信息结束 ===")

def format_instruction_chat(example):
    """将示例格式化为适用于指令微调的对话模板"""
    # 从示例中提取内容
    content = example.get("content", example.get("text", ""))

    # 构建法律文档的“指令-响应”格式
    instruction = (
        "你是一位精通新加坡法律的助手。请提供关于以下法律文档的准确信息。"
    )

    # 格式化为对话形式
    parts = [
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>

",
        instruction,
        "<|eot_id|><|start_header_id|>user<|end_header_id|>

",
        "请解读此新加坡法律条文:<|eot_id|>",
        "<|start_header_id|>assistant<|end_header_id|>

",
        content,
        "<|eot_id|>",
    ]
    formatted_text = "".join(parts)

    return {"text": formatted_text}

def load_and_prepare_dataset(dataset_path):
    """加载并预处理新加坡立法数据集"""
    logger.info(f"从{dataset_path}加载数据集")

    try:
        dataset = load_from_disk(dataset_path)
        logger.info(f"数据集加载成功:{dataset}")

        # 应用对话格式处理
        logger.info("应用指令格式处理...")
        dataset = dataset.map(
            format_instruction_chat,
            remove_columns=dataset["train"].column_names,
        )

        # 验证数据集结构
        logger.info(
            f"格式化文本示例:{dataset['train'][0]['text'][:120]}..."
        )

        return dataset

    except Exception as e:
        logger.error(f"数据集加载失败:{e}")
        raise

def find_latest_checkpoint(output_dir):
    """查找输出目录中最新的检查点(checkpoint)"""
    if not os.path.exists(output_dir):
        return None

    checkpoints = []
    for item in os.listdir(output_dir):
        if item.startswith("checkpoint-") and os.path.isdir(
            os.path.join(output_dir, item)
        ):
            try:
                step_num = int(item.split("-")[1])
                checkpoints.append(
                    (step_num, os.path.join(output_dir, item))
                )
            except (ValueError, IndexError):
                continue

    if checkpoints:
        # 返回步数最大的检查点
        latest_step, latest_path = max(checkpoints, key=lambda x: x[0])
        logger.info(
            f"找到最新检查点:{latest_path}(步数:{latest_step})"
        )
        return latest_path

    return None

def main():
    # 固定参数——基于经验总结的保守设置
    # 模型与数据集配置
    model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    dataset_path = "./hf_dataset/sg_legislation_instructions"
    output_dir = "./llama-3.1-8b-instruct-sg-legislation"

    # 训练参数——经实践验证的稳定超参数
    learning_rate = 2e-5  # 保守学习率,确保稳定性
    num_epochs = 2  # 训练轮次
    batch_size = 4  # 批次大小
    gradient_accumulation_steps = 2  # 梯度累积步数
    max_seq_length = 2048  # 最大序列长度
    warmup_steps = 500  # 学习率预热步数
    weight_decay = 0.1  # 强正则化系数
    max_grad_norm = 0.3  # 保守梯度裁剪阈值

    # LoRA参数——为提升稳定性降低参数设置
    lora_r = 16  # 降低秩以提升稳定性
    lora_alpha = 16  # LoRA缩放系数
    lora_dropout = 0.2  # LoRA dropout比例

    # 早停与稳定性设置
    early_stopping_patience = 4  # 早停容忍次数
    eval_steps = 250  # 验证步数间隔
    save_steps = 500  # 保存检查点步数间隔

    # Weights & Biases配置
    wandb_project = "llama-3.1-8b-instruct-sg-legislation"
    wandb_run_name = None  # 运行名称(默认自动生成)

    # 其他参数
    resume_from_checkpoint = None  # 从检查点恢复训练(默认不恢复)
    seed = 42  # 随机种子(确保可复现性)

    # 配置日志记录
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        handlers=[
            logging.FileHandler("finetune_llama.log"),  # 日志写入文件
            logging.StreamHandler(sys.stdout),  # 日志输出到控制台
        ],
    )

    global logger
    logger = logging.getLogger(__name__)

    # 打印系统信息
    print_system_info()

    # 设置随机种子
    torch.manual_seed(seed)
    np.random.seed(seed)

    # 初始化Weights & Biases
    run_name = (
        wandb_run_name
        or f"llama-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
    )
    wandb.init(
        project=wandb_project,
        name=run_name,
        config={
            "model_name": model_name,
            "dataset_path": dataset_path,
            "learning_rate": learning_rate,
            "num_epochs": num_epochs,
            "batch_size": batch_size,
            "gradient_accumulation_steps": gradient_accumulation_steps,
            "max_seq_length": max_seq_length,
            "warmup_steps": warmup_steps,
            "weight_decay": weight_decay,
            "max_grad_norm": max_grad_norm,
            "lora_r": lora_r,
            "lora_alpha": lora_alpha,
            "lora_dropout": lora_dropout,
            "early_stopping_patience": early_stopping_patience,
            "seed": seed,
        },
    )

    logger.info(
        "开始针对RTX 5090显卡微调Llama 3.1 8B Instruct模型..."
    )
    logger.info(
        f"固定参数:模型={model_name},学习率={learning_rate},训练轮次={num_epochs}"
    )

    # 自动检测检查点(若未指定恢复路径)
    if resume_from_checkpoint is None:
        latest_checkpoint = find_latest_checkpoint(output_dir)
        if latest_checkpoint:
            logger.info(
                f"自动从最新检查点恢复训练:{latest_checkpoint}"
            )
            resume_from_checkpoint = latest_checkpoint
        else:
            logger.info(
                "未找到现有检查点,从头开始训练"
            )
    elif resume_from_checkpoint and not os.path.exists(resume_from_checkpoint):
        logger.warning(
            f"未找到指定检查点{resume_from_checkpoint},从头开始训练"
        )
        resume_from_checkpoint = None

    # 加载数据集
    dataset = load_and_prepare_dataset(dataset_path)

    # 借助Unsloth优化加载模型和分词器
    logger.info("借助Unsloth加载模型和分词器...")
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_name,
        max_seq_length=max_seq_length,
        dtype=None,  # 自动检测数据类型
        load_in_4bit=True,  # 针对RTX 5090使用4位量化
        device_map="auto",  # 自动分配设备
    )

    # 配置LoRA
    model = FastLanguageModel.get_peft_model(
        model,
        r=lora_r,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        bias="none",
        use_gradient_checkpointing="unsloth",  # 使用Unsloth的梯度检查点
        random_state=seed,
    )

    # 配置训练参数
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        max_grad_norm=max_grad_norm,
        warmup_steps=warmup_steps,
        lr_scheduler_type="cosine",  # 稳定的余弦学习率调度器
        logging_steps=50,  # 日志记录步数间隔
        eval_steps=eval_steps,  # 验证步数间隔
        save_steps=save_steps,  # 保存检查点步数间隔
        eval_strategy="steps",  # 按步数进行验证(原参数名:evaluation_strategy)
        save_strategy="steps",  # 按步数保存检查点
        load_best_model_at_end=True,  # 训练结束后加载最佳模型
        metric_for_best_model="eval_loss",  # 以验证损失为最佳模型评价指标
        greater_is_better=False,  # 验证损失越小越好
        report_to="wandb",  # 日志上报至wandb
        run_name=run_name,  # 运行名称
        seed=seed,  # 随机种子
        data_seed=seed,  # 数据随机种子
        # RTX 5090优化设置
        bf16=True,  # 使用bf16精度匹配模型
        fp16=False,  # 禁用fp16
        dataloader_pin_memory=True,  # 启用数据加载器内存锁定
        dataloader_num_workers=4,  # 数据加载器工作进程数
        remove_unused_columns=False,  # 不删除未使用列
        # 稳定性提升设置
        save_safetensors=True,  # 以safetensors格式保存模型
        ddp_find_unused_parameters=False,  # 禁用分布式训练中未使用参数检查
    )

    # 初始化回调函数,增强稳定性监控
    stability_callback = StabilityCallback(
        divergence_threshold=1.3,  # 更保守的训练偏离阈值
        min_steps_before_check=100,  # 更早开始检查
        patience=2,  # 训练偏离容忍次数更少
        no_improvement_patience=5,  # 无提升时停止训练
        gradient_explosion_threshold=5.0,  # 梯度超过5.0时停止训练
    )

    early_stopping_callback = EarlyStoppingCallback(
        early_stopping_patience=early_stopping_patience,
    )

    # 初始化训练器
    logger.info("初始化针对RTX 5090优化的训练器...")
    trainer = RTX5090OptimizedTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
        dataset_text_field="text",
        max_seq_length=max_seq_length,
        tokenizer=tokenizer,
        packing=True,  # 启用数据打包
        dataset_kwargs={
            "add_special_tokens": False,
            "append_concat_token": False,
        },
        callbacks=[stability_callback, early_stopping_callback],
    )

    # 记录训练信息
    logger.info(f"训练样本数量:{len(dataset['train'])}")
    logger.info(f"验证样本数量:{len(dataset['validation'])}")
    logger.info(
        f"有效批次大小:{batch_size * gradient_accumulation_steps}"
    )

    # 开始训练
    logger.info("开始训练...")
    try:
        trainer.train(resume_from_checkpoint=resume_from_checkpoint)
        logger.info("训练成功完成!")
    except Exception as e:
        logger.error(f"训练失败:{e}")
        raise

    # 保存模型
    logger.info("保存模型...")
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)

    # 借助Unsloth保存合并后的模型
    logger.info("保存合并后的模型...")
    try:
        model.save_pretrained_merged(
            output_dir,
            tokenizer,
            save_method="merged_16bit",
        )
        logger.info("合并后的模型保存成功!")
    except Exception as e:
        logger.warning(f"合并后的模型保存失败:{e}")

    # 保存训练配置
    config_path = os.path.join(output_dir, "training_config.json")
    with open(config_path, "w") as f:
        json.dump(
            {
                "model_name": model_name,
                "dataset_path": dataset_path,
                "training_args": training_args.to_dict(),
                "lora_config": {
                    "r": lora_r,
                    "alpha": lora_alpha,
                    "dropout": lora_dropout,
                },
                "environment": {
                    "torch_version": torch.__version__,
                    "cuda_version": torch.version.cuda,
                    "gpu_name": (
                        torch.cuda.get_device_name(0)
                        if torch.cuda.is_available()
                        else "N/A"
                    ),
                    "optimized_for": "RTX-5090-Blackwell",
                },
            },
            f,
            indent=2,
        )

    logger.info(f"模型已保存至{output_dir}")
    logger.info("微调成功完成!")

    # 结束wandb运行
    wandb.finish()

if __name__ == "__main__":
    main()

该脚本使用了多个库,核心包括Unsloth及其快速模型加载功能、用于LoRA支持的Hugging Face Transformers和PEFT库、用于有监督训练的TRL库(含SFTTrainer),还使用wandb进行实验跟踪,以及datasets库加载和管理本地数据集。

我编写了几个实用函数: find_latest_checkpoint用于自动从最新检查点恢复训练, format_instruction_chat用于将原始示例格式化为“指令-响应”对话形式, load_and_prepare_dataset()用于加载数据集, print_system_info则在启动时打印系统配置信息。这些抽象设计让主脚本结构更清晰。

此外,还有两个重要类对训练框架进行了扩展。第一个是 RTX5090OptimizedTrainer,它继承自 SFTTrainer并覆写了损失计算方法,可检测并处理NaN/Inf等不稳定数值,避免训练崩溃;第二个是 StabilityCallback,通过训练偏离检测、基于性能停滞的早停、梯度范数爆炸检测等机制监控训练状态,与Hugging Face内置的 EarlyStoppingCallback配合使用,确保即使在长时间或不稳定的微调过程中,训练也能稳定进行。

main函数统筹整个微调流程:首先声明超参数并记录到wandb,配置日志记录,设置随机种子以确保可复现性,打印系统信息到控制台和文件;若启用检查点功能或有必要,脚本会自动从最新保存的状态恢复训练。

以下是所用超参数列表:

learning_rate = 2e-5:保守的学习率,控制训练过程中模型参数的更新速度num_epochs = 2:训练数据集的完整遍历次数batch_size = 4:每步训练同时处理的示例数量gradient_accumulation_steps = 2:累积2步梯度后再更新权重,等效于批次大小为8max_seq_length = 2048:单个训练示例(输入+输出)的最大token数量warmup_steps = 500:从0逐步提升至目标学习率的步数weight_decay = 0.1:通过惩罚大权重实现L2正则化,防止过拟合max_grad_norm = 0.3:梯度裁剪阈值,防止更新不稳定或梯度爆炸lora_r = 16:低秩适应(LoRA)矩阵的维度(秩),数值越小,内存占用越少,稳定性越高lora_alpha = 16:调整LoRA更新相对于基础模型权重的强度缩放因子lora_dropout = 0.2:应用于LoRA适配器的dropout比例,用于减少过拟合

接下来,从本地加载数据集,将其格式化为适用于对话模型的指令式提示,拆分为训练集和验证集。然后通过Unsloth的 FastLanguageModel.from_pretrained加载模型,该方法会应用4位量化并自动将模型分配到可用显卡内存中。借助 get_peft_model注入LoRA适配器,目标模块包括 q_proj v_proj gate_proj等投影矩阵。

训练参数通过Hugging Face的 TrainingArguments配置,包括混合精度(bfloat16)、余弦学习率衰减、频繁验证、自动检查点等设置。随后初始化自定义的 RTX5090OptimizedTrainer,传入模型、数据集、分词器和回调函数,调用 trainer.train启动训练;若启用相关功能,模型会根据验证损失自动加载最优检查点。

训练完成后,会执行以下操作: trainer.save_model保存微调过程中的轻量级适配器权重,保留任务特定参数与原始基础模型的分离; tokenizer.save_pretrained保存分词器资源(词汇表、配置、新增特殊token),确保后续推理流程能复现完全一致的分词环境,生成 tokenizer.json tokenizer_config.json special_tokens_map.json三个文件; model.save_pretrained_merged将基础模型与适配器权重融合,以16位精度导出,生成可直接部署或共享的完整独立模型;最后,将所有配置信息写入 training_config.json文件。

运行脚本

脚本编写完成后,就可以开始运行了。建议在后台运行脚本——尤其是通过SSH远程连接到机器时,避免因连接断开导致微调中断,不得不重新开始。

实现后台运行的方法有多种,最简单的是使用tmux。若未安装tmux,先正常安装,然后创建新会话:


$ tmux new -s finetune

之后即可运行脚本:


$ python finetune_llama.py

任何时候都可按 ctrl+b再按 d断开会话;若需重新连接会话,执行以下命令:


$ tmux attach -t finetune

以下是脚本运行时的初始输出示例:


sausheong@riptide$ source venv/bin/activate
(venv) sausheong@riptide$ python finetune_llama.py
🦥 Unsloth:将优化你的电脑,实现2倍速免费微调。
🦥 Unsloth Zoo将优化所有组件,提升训练速度!
INFO 07-13 11:23:31 [__init__.py:244] 自动检测到CUDA平台。
2025-07-13 11:23:32,919 - INFO - === 系统信息 ===
2025-07-13 11:23:32,919 - INFO - Python版本:3.12.9 | packaged by Anaconda, Inc. | (main, Feb  6 2025, 18:56:27) [GCC 11.2.0]
2025-07-13 11:23:32,919 - INFO - PyTorch版本:2.7.0+cu128
2025-07-13 11:23:32,919 - INFO - CUDA是否可用:True
2025-07-13 11:23:32,919 - INFO - CUDA版本:12.8
2025-07-13 11:23:32,919 - INFO - 显卡数量:1
2025-07-13 11:23:32,920 - INFO - 显卡名称:NVIDIA GeForce RTX 5090
2025-07-13 11:23:32,920 - INFO - 显卡内存:31.4 GB
2025-07-13 11:23:32,920 - INFO - === 系统信息结束 ===
wandb:当前登录用户:sausheong(sausheongchang),登录地址:<https://api.wandb.ai>。使用`wandb login --relogin`可强制重新登录
wandb:使用0.21.0版本的wandb跟踪运行
wandb:运行数据本地保存路径:/home/sausheong/finetune_sg_legislation/wandb/run-20250713_112333-70utt9lj
wandb:执行`wandb offline`可关闭同步功能
wandb:正在同步运行llama-3.1-rtx5090-20250713-112332
wandb:⭐️ 查看项目:<https://wandb.ai/sausheongchang/llama-3.1-sg-legislation>
wandb:🚀 查看运行:<https://wandb.ai/sausheongchang/llama/runs/70utt9lj>
2025-07-13 11:23:34,495 - INFO - 开始针对RTX 5090显卡微调Llama 3.1 8B Instruct模型...
2025-07-13 11:23:34,495 - INFO - 固定参数:模型=meta-llama/Meta-Llama-3.1-8B-Instruct,学习率=2e-05,训练轮次=2
2025-07-13 11:23:34,495 - INFO - 未找到现有检查点,从头开始训练
2025-07-13 11:23:34,495 - INFO - 从./hf_dataset/sg_legislation_instructions加载数据集
2025-07-13 11:23:34,501 - INFO - 数据集加载成功:DatasetDict({
    train: Dataset({
        features: ['instruction', 'input', 'output'],
        num_rows: 177832
    })
    validation: Dataset({
        features: ['instruction', 'input', 'output'],
        num_rows: 22229
    })
    test: Dataset({
        features: ['instruction', 'input', 'output'],
        num_rows: 22229
    })
})
2025-07-13 11:23:34,501 - INFO - 应用指令格式处理...
Map: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 177832/177832 [00:01<00:00, 140056.62 examples/s]
Map: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 22229/22229 [00:00<00:00, 115865.30 examples/s]
Map: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 22229/22229 [00:00<00:00, 138279.64 examples/s]
2025-07-13 11:23:36,126 - INFO - 格式化文本示例:<|begin_of_text|><|start_header_id|>system<|end_header_id|>

你是一位精通新加坡法律的助手...
2025-07-13 11:23:36,126 - INFO - 借助Unsloth加载模型和分词器...
==((====))==  Unsloth 2025.6.12:快速Llama模型优化。Transformers版本:4.52.4。vLLM版本:0.9.2rc2.dev29+gf73d02aad。
   \   /|    NVIDIA GeForce RTX 5090显卡。显卡数量=1。最大内存:31.367 GB。平台:Linux。
O^O/ \_/     PyTorch版本:2.7.0+cu128。CUDA版本:12.0。CUDA工具包版本:12.8。Triton版本:3.3.1
        /    Bfloat16精度=启用。FA [Xformers=无。FA2=禁用]
 "-____-"     免费许可:<http://github.com/unslothai/unsloth>
Unsloth:启用快速下载功能——忽略红色的下载进度条!
2025-07-13 11:23:41,423 - INFO - 将使用设备0上90%的内存存储模型,10%作为缓冲区以避免内存不足(OOM)。你可自行设置`max_memory`参数以使用更多内存(风险自负)。
Unsloth:支持dropout=0的快速优化。当前使用dropout=0.2。
Unsloth将优化除LoRA矩阵外的所有层,可能会对性能造成一定影响。
Unsloth 2025.6.12已优化32层,其中QKV层0个、O层0个、MLP层0个。
2025-07-13 11:23:50,215 - INFO - 初始化针对RTX 5090优化的训练器...
Unsloth:对["text"]进行分词:100%|████████████████████████████████████████████████████████████████████████████| 177832/177832 [00:02<00:00, 70938.77 examples/s]
Unsloth:对["text"]进行分词:100%|██████████████████████████████████████████████████████████████████████████████| 22229/22229 [00:00<00:00, 86162.61 examples/s]
2025-07-13 11:23:53,202 - INFO - 训练样本数量:177832
2025-07-13 11:23:53,202 - INFO - 验证样本数量:22229
2025-07-13 11:23:53,202 - INFO - 有效批次大小:8
2025-07-13 11:23:53,202 - INFO - 开始训练...
==((====))==  Unsloth——2倍速免费微调 | 使用显卡数量=1
   \   /|    样本总数=177,832 | 训练轮次=2 | 总步数=44,458
O^O/ \_/     单设备批次大小=4 | 梯度累积步数=2
        /    数据并行显卡数量=1 | 总批次大小(4×2×1)=8
 "-____-"     可训练参数=41,943,040 / 8,000,000,000(0.52%可训练)
  0%|                                                                                                                                  | 0/44458 [00:00<?, ?it/s]
  0%|                                                                                                                       | 33/44458 [00:18<6:45:46,  1.82it/s]

等待微调完成

现在可以稍作休息(也只是相对而言)。你可以随时重新连接tmux会话查看训练进度,也可直接登录https://www.wandb.ai,进入对应项目查看可视化数据或日志。你完全可以启动微调后去吃午饭或处理其他事务,远程监控机器上的训练情况。微调可能需要数小时甚至数天,而本次微调我大约用了6小时完成。

微调完成后,输出目录中会生成类似以下的文件(同时包含多个检查点目录):


(venv) sausheong@riptide$ ls -lh
...
-rw-rw-r--  1 sausheong sausheong 1008 7月  13 13:42 config.json
-rw-rw-r--  1 sausheong sausheong  234 7月  13 13:42 generation_config.json
-rw-rw-r--  1 sausheong sausheong 4.7G 7月  13 13:44 model-00001-of-00004.safetensors
-rw-rw-r--  1 sausheong sausheong 4.7G 7月  13 13:46 model-00002-of-00004.safetensors
-rw-rw-r--  1 sausheong sausheong 4.6G 7月  13 13:48 model-00003-of-00004.safetensors
-rw-rw-r--  1 sausheong sausheong 1.1G 7月  13 13:48 model-00004-of-00004.safetensors
-rw-rw-r--  1 sausheong sausheong  24K 7月  13 13:42 model.safetensors.index.json
-rw-rw-r--  1 sausheong sausheong 1.9K 7月  13 13:42 README.md
-rw-rw-r--  1 sausheong sausheong  454 7月  13 13:42 special_tokens_map.json
-rw-rw-r--  1 sausheong sausheong  50K 7月  13 13:42 tokenizer_config.json
-rw-rw-r--  1 sausheong sausheong  17M 7月  13 13:42 tokenizer.json
-rw-rw-r--  1 sausheong sausheong 6.1K 7月  13 13:42 training_args.bin
-rw-rw-r--  1 sausheong sausheong 4.7K 7月  13 13:48 training_config.json

其中, config.json文件存储模型架构细节(如层数、隐藏层大小、注意力头数量等),这对推理阶段重建模型至关重要。完整的合并模型权重分散存储在4个文件中——从 model-00001-of-00004.safetensors model-00004-of-00004.safetensors,这些文件包含训练后的参数,由 model.safetensors.index.json文件索引(该文件作为分片权重文件的指针,让模型加载器如Hugging Face的 AutoModelForCausalLM知道如何正确加载这些权重)。

分词也是推理过程中的关键环节,由多个文件共同处理: tokenizer.json包含序列化的分词器词汇表及其内部结构; tokenizer_config.json补充分词器设置(如是否将文本小写、使用何种分词器类等); special_tokens_map.json将BOS(序列开始)、EOS(序列结束)、PAD(填充)等关键token映射到对应的token ID,确保分词器和模型能正确处理这些特殊token。 generation_config.json是可选但实用的文件,存储温度(temperature)、top-k、top-p等默认文本生成参数,虽非必需,但能确保生成文本时的行为一致性。

其余文件则不用于推理: README.md通常记录模型用途、使用方法及重要说明; training_args.bin是Hugging Face TrainingArguments的序列化文件,便于重新训练或验证训练设置; training_config.json则包含前文提到的超参数的可读版本。这些文件对文档记录或重新训练有帮助,但推理时无需使用。

可能出现的问题

当然,上述情况仅在一切顺利时发生。微调过程中其实可能出现多种问题。

微调过程中最常见的问题之一与数据本身相关。数据质量差(如格式不一致、文本未清洗、示例标签错误)会让模型困惑,影响学习效果。若数据集格式与模型预期的“输入-输出”结构不匹配(如指令微调场景),训练可能无报错但生成的模型无法使用。

序列长度超过模型最大token限制,可能导致文本被无声截断甚至训练崩溃。此外,数据量过少可能导致欠拟合,使模型无法泛化;数据泄露(验证集与训练集重叠)则会导致性能指标虚高,误导判断。

训练配置是另一常见问题来源。学习率过高可能导致训练损失骤升或完全偏离,过低则可能使训练陷入停滞。

批次大小不当也会引发不稳定性:批次过大可能导致显卡内存溢出,过小则会使训练过程波动大、效率低。

训练时间过长或数据集过小时,模型可能过拟合——在训练集上表现良好,但在未见过的数据上表现糟糕;反之,若模型容量不足或训练时间不够,无法捕捉数据中的规律,则会出现欠拟合,导致训练集和验证集损失均居高不下。

硬件和环境相关问题也可能导致训练失败。最常见的是显卡内存不足,尤其在微调大型模型或使用长序列时更容易发生。使用错误的CUDA版本或不兼容的PyTorch版本,可能导致模型初始化失败或运行时报错。甚至一些不那么明显的问题(如显卡热节流)也会在无明确提示的情况下大幅降低训练速度。

这些问题都凸显了从检查点恢复训练的重要性——没人希望每次遇到问题都要从头开始训练。

使用模型进行推理

微调结束后,脚本已将适配器与基础模型合并,因此我们只需编写一个常规的推理脚本即可。以下是一个使用合并后模型的简单推理脚本:


import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
import argparse
import readline

def load_model(model_path, device):
    """
    从指定路径加载模型和分词器
    """
    print(f"从{model_path}加载模型...")
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # CPU环境可能需要使用不同精度加载
    torch_dtype = (
        torch.bfloat16
        if device == "cuda" and torch.cuda.is_bf16_supported()
        else torch.float16
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch_dtype,
        device_map="auto",  # 由transformers自动处理设备分配
    )
    print("模型加载成功。")
    return model, tokenizer

def format_prompt(prompt):
    """
    按Llama 3的指令格式处理提示词
    """
    messages = [
        {
            "role": "system",
            "content": (
                "你是一位精通新加坡法律的助手。"
            ),
        },
        {"role": "user", "content": prompt},
    ]

    # 使用分词器的对话模板实现最可靠的格式化
    formatted_prompt = AutoTokenizer.from_pretrained(
        "./llama-3.1-8b-instruct-sg-legislation/checkpoint-2500"
    ).apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    return formatted_prompt

def generate_response(
    model, tokenizer, prompt, device, max_new_tokens=512, temperature=0.7
):
    """
    生成模型响应
    """
    formatted_prompt = format_prompt(prompt)
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(
        model.device
    )

    streamer = TextStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    # 生成响应
    _ = model.generate(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        # 遇到以下token时停止生成
        eos_token_id=[
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ],
        pad_token_id=tokenizer.eos_token_id,
    )
    return ""  # 响应由streamer直接输出到标准输出(stdout)

def main():
    """
    运行推理脚本的主函数
    """
    parser = argparse.ArgumentParser(
        description="用于微调后Llama 3.1模型的推理脚本"
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="./llama-3.1-8b-instruct-sg-legislation",
        help="微调后模型的路径",
    )
    parser.add_argument(
        "--prompt", type=str, default=None, help="待推理的单个提示词"
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=1024,
        help="生成新token的最大数量",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="采样温度(控制生成文本的随机性)",
    )
    args = parser.parse_args()

    # 确定运行设备
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"使用设备:{device}")

    # 加载模型和分词器
    model, tokenizer = load_model(args.model_path, device)

    if args.prompt:
        # 单提示词模式
        print(f"提示词:{args.prompt}")
        print("
响应:")
        generate_response(
            model,
            tokenizer,
            args.prompt,
            device,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
        )
        print("
" + "=" * 80 + "
")
    else:
        # 交互模式
        print("进入交互模式。输入'exit'或'quit'退出。")
        try:
            while True:
                prompt = input("提示词:")
                if prompt.lower() in ["exit", "quit"]:
                    break
                print("
响应:")
                generate_response(
                    model,
                    tokenizer,
                    prompt,
                    device,
                    max_new_tokens=args.max_new_tokens,
                    temperature=args.temperature,
                )
                print("
" + "=" * 80 + "
")
        except KeyboardInterrupt:
            print("
正在退出...")
        except Exception as e:
            print(f"发生错误:{e}")

if __name__ == "__main__":
    main()

该推理脚本仅使用Hugging Face库,因此可在更多环境中运行。它能加载合并后的模型,支持单提示词推理或交互模式推理。

load_model函数初始化并加载分词器和模型; format_prompt函数将用户输入包装成适合Llama 3对话生成的结构化格式,先定义系统提示(告知模型需作为精通新加坡法律的助手),再调用分词器的 apply_chat_template方法生成格式正确的提示词字符串。

generate_response函数负责实际推理:先格式化提示词并将其分词为张量,将输入转移到模型所在设备,设置 TextStreamer实现响应的流式输出;随后调用模型的 generate方法,传入最大新token数量、采样温度等参数,并指定序列结束(EOS)token以控制生成停止时机。由于响应已通过streamer直接输出到标准输出,因此该函数返回空字符串。

运行推理脚本

让我们来测试一下这个脚本!


$ python infer.py --prompt "解读《反贪法》第10条中关于腐败性促成或撤回投标的犯罪行为,包括“代理人”的定义及规定的处罚,以2023年1月1日生效的法律版本为准"
使用设备:cuda
从./llama-3.1-8b-instruct-sg-legislation加载模型...
模型加载成功。
提示词:解读《反贪法》第10条中关于腐败性促成或撤回投标的犯罪行为,包括“代理人”的定义及规定的处罚,以2023年1月1日生效的法律版本为准

响应:
新加坡《反贪法》第10条针对腐败性促成或撤回投标的犯罪行为作出规定。截至2023年1月1日,该条款内容如下:

**第10条 腐败性促成或撤回投标**

(1)任何人如为自身或他人利益,以腐败方式促成或协助促成任何投标、报价、要约或接受(涉及为获取工程、服务、货物或其他目的而与政府、政府部门或政府官员/雇员签订或拟签订的合同),或撤回或试图撤回该等投标、报价、要约或接受,即属犯罪。

(2)任何人如为自身或他人利益,以腐败方式撤回或试图撤回任何投标、报价、要约或接受(涉及为获取工程、服务、货物或其他目的而与政府、政府部门或政府官员/雇员签订或拟签订的合同),即属犯罪。

(3)曾担任或现任政府或政府部门官员/雇员的人士,如为自身或他人利益,促成或协助促成任何投标、报价、要约或接受(涉及为获取工程、服务、货物或其他目的而与政府、政府部门签订或拟签订的合同),或撤回或试图撤回该等投标、报价、要约或接受,即属犯罪。

**“代理人”的定义**

在本条语境下,“代理人”包括代表他人与政府或政府部门打交道的任何人,也包括曾担任或现任政府或政府部门官员/雇员的人士。

**规定的处罚**

《反贪法》第10条规定的犯罪处罚如下:
- 初犯者:最高可处罚款10万新元或监禁最长5年,或两者并罚。
- 再犯或多次犯罪者:最高可处罚款20万新元或监禁最长7年,或两者并罚。

请注意,上述处罚可能会发生变更,建议查阅最新法律法规以获取最准确的信息。

成功了!

关于微调的一些思考

对于微调是否困难,人们一直存在不同看法,这让未实际尝试过的人感到困惑。现在你已和我一同经历了(部分)微调过程,你有什么看法呢?

我个人认为,微调既简单又困难(听起来像是敷衍的说法,但确实如此)。说它简单,是因为步骤相对清晰;但要确定合适的参数和配置,却可能让人不知所措,需要大量试错。我曾看过一档名为《乐高大师》(Lego Masters)的电视节目,惊叹于乐高大师仅用塑料积木就能创造出极具创意的作品。微调大语言模型的过程,给我的感觉与此类似。

按回车键或点击即可查看完整尺寸图片

图片:内森·萨瓦亚(Nathan Sawaya)的作品《黄色》(Yellow),该作品属于“积木的艺术”(The Art of the Brick)展览,展出的所有精致雕塑均由乐高积木搭建而成。

然而,尽管模型微调过程无误,初步看生成的回答也正确,但仔细检查后会发现,其表现并不比基础模型好多少。这可能由多种原因导致。

首先,Llama-3.1–8B模型的容量可能不足以存储和检索精准的法规事实,尤其是我使用的LoRA技术仅更新了模型权重的一小部分。这意味着模型仍可能依赖预训练阶段习得的通用语言模式,

关注“AI拉呱”,评论+转发此文即可私信获取一份教程+一份学习书单!

  • 全部评论(0)
最新发布的资讯信息
【系统环境|】八股已死、场景当立(场景篇-设计模式篇)(2025-11-22 23:27)
【系统环境|】群、环、域(2025-11-22 23:26)
【系统环境|】深度解析:基于Python的分布式缓存系统实现与性能优化(2025-11-22 23:26)
【系统环境|】TP区块链下载全解析:从技术原理到代码实现(2025-11-22 23:25)
【系统环境|】大模型在急性肾衰竭预测及临床方案制定中的应用研究(2025-11-22 23:25)
【系统环境|】特价股票投资中的可持续供应链管理整合方法(2025-11-22 23:24)
【系统环境|】第193期 如何微调大语言模型(LLM)(内含源码细节)(2025-11-22 23:23)
【系统环境|】用Python构建智能推荐系统:技术赋能美好生活(2025-11-22 23:23)
【系统环境|】企业估值中的氢能源应用评估(2025-11-22 23:22)
【系统环境|】ansible 学习之路(2025-11-22 23:22)
手机二维码手机访问领取大礼包
返回顶部