6. Pre-training & Loading models


Text Generation

In order to train a model, we need that model to be able to generate new tokens. Then, we compare the generated tokens with the expected ones in order to train the model to learn the tokens it needs to generate.

As in the previous examples we already predicted some tokens, it's possible to reuse that function for this purpose.

The goal of this sixth phase is very simple: train the model from scratch. For this, the previous LLM architecture will be used, with some loops going over the data sets, using the defined loss functions and optimizer to train all the parameters of the model.

Text Evaluation

In order to perform a correct training, it's necessary to check the predictions obtained against the expected token. The goal of the training is to maximize the likelihood of the correct token, which involves increasing its probability relative to other tokens.

In order to maximize the probability of the correct token, the weights of the model must be modified so that this probability is maximised. The updates of the weights are done via backpropagation. This requires a loss function to minimize. In this case, the function will be the difference between the performed prediction and the desired one.

However, instead of working with the raw predictions, it works with their natural logarithm (base e). So if the current predicted probability of the expected token was 7.4541e-05, its natural logarithm is approximately -9.5042. Then, for each entry with a context length of 5 tokens, for example, the model produces 5 predictions, the expected targets being the input shifted one position to the left (the first 4 targets are the last 4 tokens of the input, and the fifth is the newly predicted one). Therefore, for each entry we will have 5 predictions in that case (even if the first 4 targets were part of the input, the model doesn't know this) with 5 expected tokens, and therefore 5 probabilities to maximize.
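As a tiny sketch with hypothetical token ids: for a context length of 5, the expected targets are simply the inputs shifted one position to the left:

token_ids = [6109, 3626, 6100, 345, 11, 290]  # 6 consecutive (made-up) token ids
inputs  = token_ids[:5]   # [6109, 3626, 6100, 345, 11]
targets = token_ids[1:6]  # [3626, 6100, 345, 11, 290] -> 5 expected tokens
# The model emits one prediction per input position, and each prediction
# is compared against the corresponding target token.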

Therefore, after taking the natural logarithm of each prediction, the average is calculated and the sign flipped from negative to positive (this is called the cross entropy loss), and that's the number to reduce as close to 0 as possible, because the natural logarithm of 1 is 0:
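loss = -(1/N) * sum(ln p(expected token_i))

For example, a tiny numeric sketch with made-up probabilities for 5 expected tokens:

import torch

# Hypothetical probabilities the model assigned to the 5 expected tokens
probas = torch.tensor([7.4541e-05, 3.1e-03, 1.2e-02, 8.9e-04, 2.5e-01])
log_probas = torch.log(probas)          # natural logarithms (all negative)
avg_log_proba = torch.mean(log_probas)  # average log-likelihood
cross_entropy = -avg_log_proba          # flip the sign -> positive loss to minimize
print(cross_entropy)                    # the closer to 0, the better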

Another way to measure how good the model is is called perplexity. Perplexity is a metric used to evaluate how well a probability model predicts a sample. In language modelling, it represents the model's uncertainty when predicting the next token in a sequence. For example, a perplexity value of 48,725 means that, when it needs to predict a token, the model is as unsure as if it had to choose uniformly among 48,725 tokens of the vocabulary.
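Perplexity is simply the exponential of the cross entropy loss; continuing the sketch above:

perplexity = torch.exp(cross_entropy)
print(perplexity)  # e.g. a loss of ~10.79 corresponds to a perplexity of ~48,725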

Pre-Train Example

Previous code used here but already explained in previous sections
"""
This is code explained before so it won't be explained again
"""

import tiktoken
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


class GPTDatasetV1(Dataset):
    def __init__(self, txt, tokenizer, max_length, stride):
        self.input_ids = []
        self.target_ids = []

        # Tokenize the entire text
        token_ids = tokenizer.encode(txt, allowed_special={"<|endoftext|>"})

        # Use a sliding window to chunk the book into overlapping sequences of max_length
        for i in range(0, len(token_ids) - max_length, stride):
            input_chunk = token_ids[i:i + max_length]
            target_chunk = token_ids[i + 1: i + max_length + 1]
            self.input_ids.append(torch.tensor(input_chunk))
            self.target_ids.append(torch.tensor(target_chunk))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]


def create_dataloader_v1(txt, batch_size=4, max_length=256,
                         stride=128, shuffle=True, drop_last=True, num_workers=0):
    # Initialize the tokenizer
    tokenizer = tiktoken.get_encoding("gpt2")

    # Create dataset
    dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)

    # Create dataloader
    dataloader = DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers)

    return dataloader


class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by n_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x)  # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.reshape(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)  # optional projection

        return context_vec


class LayerNorm(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift


class GELU(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(
            torch.sqrt(torch.tensor(2.0 / torch.pi)) *
            (x + 0.044715 * torch.pow(x, 3))
        ))


class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
            GELU(),
            nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
        )

    def forward(self, x):
        return self.layers(x)


class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = MultiHeadAttention(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            context_length=cfg["context_length"],
            num_heads=cfg["n_heads"],
            dropout=cfg["drop_rate"],
            qkv_bias=cfg["qkv_bias"])
        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(cfg["emb_dim"])
        self.norm2 = LayerNorm(cfg["emb_dim"])
        self.drop_shortcut = nn.Dropout(cfg["drop_rate"])

    def forward(self, x):
        # Shortcut connection for attention block
        shortcut = x
        x = self.norm1(x)
        x = self.att(x)   # Shape [batch_size, num_tokens, emb_size]
        x = self.drop_shortcut(x)
        x = x + shortcut  # Add the original input back

        # Shortcut connection for feed-forward block
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = self.drop_shortcut(x)
        x = x + shortcut  # Add the original input back

        return x


class GPTModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
        self.drop_emb = nn.Dropout(cfg["drop_rate"])

        self.trf_blocks = nn.Sequential(
            *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])

        self.final_norm = LayerNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)

    def forward(self, in_idx):
        batch_size, seq_len = in_idx.shape
        tok_embeds = self.tok_emb(in_idx)
        pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
        x = tok_embeds + pos_embeds  # Shape [batch_size, num_tokens, emb_size]
        x = self.drop_emb(x)
        x = self.trf_blocks(x)
        x = self.final_norm(x)
        logits = self.out_head(x)
        return logits
# Download the text to train the model with
import os
import urllib.request

file_path = "the-verdict.txt"
url = "https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/main/ch02/01_main-chapter-code/the-verdict.txt"

if not os.path.exists(file_path):
    with urllib.request.urlopen(url) as response:
        text_data = response.read().decode('utf-8')
    with open(file_path, "w", encoding="utf-8") as file:
        file.write(text_data)
else:
    with open(file_path, "r", encoding="utf-8") as file:
        text_data = file.read()

total_characters = len(text_data)
tokenizer = tiktoken.get_encoding("gpt2")
total_tokens = len(tokenizer.encode(text_data))

print("Data downloaded")
print("Characters:", total_characters)
print("Tokens:", total_tokens)

# Model initialization
GPT_CONFIG_124M = {
    "vocab_size": 50257,   # Vocabulary size
    "context_length": 256, # Shortened context length (orig: 1024)
    "emb_dim": 768,        # Embedding dimension
    "n_heads": 12,         # Number of attention heads
    "n_layers": 12,        # Number of layers
    "drop_rate": 0.1,      # Dropout rate
    "qkv_bias": False      # Query-key-value bias
}

torch.manual_seed(123)
model = GPTModel(GPT_CONFIG_124M)
model.eval()
print("Model initialized")


# Functions to transform from text to token ids and from token ids to text
def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
    encoded_tensor = torch.tensor(encoded).unsqueeze(0) # add batch dimension
    return encoded_tensor

def token_ids_to_text(token_ids, tokenizer):
    flat = token_ids.squeeze(0) # remove batch dimension
    return tokenizer.decode(flat.tolist())



# Define loss functions
def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    logits = model(input_batch)
    loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())
    return loss


def calc_loss_loader(data_loader, model, device, num_batches=None):
    total_loss = 0.
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        # Reduce the number of batches to match the total number of batches in the data loader
        # if num_batches exceeds the number of batches in the data loader
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches


# Apply Train/validation ratio and create dataloaders
train_ratio = 0.90
split_idx = int(train_ratio * len(text_data))
train_data = text_data[:split_idx]
val_data = text_data[split_idx:]

torch.manual_seed(123)

train_loader = create_dataloader_v1(
    train_data,
    batch_size=2,
    max_length=GPT_CONFIG_124M["context_length"],
    stride=GPT_CONFIG_124M["context_length"],
    drop_last=True,
    shuffle=True,
    num_workers=0
)

val_loader = create_dataloader_v1(
    val_data,
    batch_size=2,
    max_length=GPT_CONFIG_124M["context_length"],
    stride=GPT_CONFIG_124M["context_length"],
    drop_last=False,
    shuffle=False,
    num_workers=0
)


# Sanity checks
if total_tokens * (train_ratio) < GPT_CONFIG_124M["context_length"]:
    print("Not enough tokens for the training loader. "
          "Try to lower the `GPT_CONFIG_124M['context_length']` or "
          "increase the `training_ratio`")

if total_tokens * (1-train_ratio) < GPT_CONFIG_124M["context_length"]:
    print("Not enough tokens for the validation loader. "
          "Try to lower the `GPT_CONFIG_124M['context_length']` or "
          "decrease the `training_ratio`")

print("Train loader:")
for x, y in train_loader:
    print(x.shape, y.shape)

print("\nValidation loader:")
for x, y in val_loader:
    print(x.shape, y.shape)

train_tokens = 0
for input_batch, target_batch in train_loader:
    train_tokens += input_batch.numel()

val_tokens = 0
for input_batch, target_batch in val_loader:
    val_tokens += input_batch.numel()

print("Training tokens:", train_tokens)
print("Validation tokens:", val_tokens)
print("All tokens:", train_tokens + val_tokens)


# Indicate the device to use
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using {device} device.")

model.to(device) # no assignment model = model.to(device) necessary for nn.Module classes



# Pre-calculate losses without starting yet
torch.manual_seed(123) # For reproducibility due to the shuffling in the data loader

with torch.no_grad(): # Disable gradient tracking for efficiency because we are not training, yet
    train_loss = calc_loss_loader(train_loader, model, device)
    val_loss = calc_loss_loader(val_loader, model, device)

print("Training loss:", train_loss)
print("Validation loss:", val_loss)


# Functions to train the model
def train_model_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
                       eval_freq, eval_iter, start_context, tokenizer):
    # Initialize lists to track losses and tokens seen
    train_losses, val_losses, track_tokens_seen = [], [], []
    tokens_seen, global_step = 0, -1

    # Main training loop
    for epoch in range(num_epochs):
        model.train()  # Set model to training mode
        
        for input_batch, target_batch in train_loader:
            optimizer.zero_grad() # Reset loss gradients from previous batch iteration
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            loss.backward() # Calculate loss gradients
            optimizer.step() # Update model weights using loss gradients
            tokens_seen += input_batch.numel()
            global_step += 1

            # Optional evaluation step
            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_loader, val_loader, device, eval_iter)
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                track_tokens_seen.append(tokens_seen)
                print(f"Ep {epoch+1} (Step {global_step:06d}): "
                      f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")

        # Print a sample text after each epoch
        generate_and_print_sample(
            model, tokenizer, device, start_context
        )

    return train_losses, val_losses, track_tokens_seen


def evaluate_model(model, train_loader, val_loader, device, eval_iter):
    model.eval()
    with torch.no_grad():
        train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
    model.train()
    return train_loss, val_loss


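# Note: generate_text (called below) is defined later in the "Generate text functions" section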
def generate_and_print_sample(model, tokenizer, device, start_context):
    model.eval()
    context_size = model.pos_emb.weight.shape[0]
    encoded = text_to_token_ids(start_context, tokenizer).to(device)
    with torch.no_grad():
        token_ids = generate_text(
            model=model, idx=encoded,
            max_new_tokens=50, context_size=context_size
        )
    decoded_text = token_ids_to_text(token_ids, tokenizer)
    print(decoded_text.replace("\n", " "))  # Compact print format
    model.train()


# Start training!
import time
start_time = time.time()

torch.manual_seed(123)
model = GPTModel(GPT_CONFIG_124M)
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0004, weight_decay=0.1)

num_epochs = 10
train_losses, val_losses, tokens_seen = train_model_simple(
    model, train_loader, val_loader, optimizer, device,
    num_epochs=num_epochs, eval_freq=5, eval_iter=5,
    start_context="Every effort moves you", tokenizer=tokenizer
)

end_time = time.time()
execution_time_minutes = (end_time - start_time) / 60
print(f"Training completed in {execution_time_minutes:.2f} minutes.")



# Show graphics with the training process
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import math
def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses):
    fig, ax1 = plt.subplots(figsize=(5, 3))
    ax1.plot(epochs_seen, train_losses, label="Training loss")
    ax1.plot(
        epochs_seen, val_losses, linestyle="-.", label="Validation loss"
    )
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel("Loss")
    ax1.legend(loc="upper right")
    ax1.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax2 = ax1.twiny()
    ax2.plot(tokens_seen, train_losses, alpha=0)
    ax2.set_xlabel("Tokens seen")
    fig.tight_layout()
    plt.show()

    # Compute perplexity from the loss values
    train_ppls = [math.exp(loss) for loss in train_losses]
    val_ppls = [math.exp(loss) for loss in val_losses]
    # Plot perplexity over tokens seen
    plt.figure()
    plt.plot(tokens_seen, train_ppls, label='Training Perplexity')
    plt.plot(tokens_seen, val_ppls, label='Validation Perplexity')
    plt.xlabel('Tokens Seen')
    plt.ylabel('Perplexity')
    plt.title('Perplexity over Training')
    plt.legend()
    plt.show()

epochs_tensor = torch.linspace(0, num_epochs, len(train_losses))
plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses)


torch.save({
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    },
"/tmp/model_and_optimizer.pth"
)
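To later restore that checkpoint and resume training, the standard PyTorch counterpart would be (a sketch, assuming the same GPT_CONFIG_124M and optimizer settings as above):

checkpoint = torch.load("/tmp/model_and_optimizer.pth")

model = GPTModel(GPT_CONFIG_124M)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=0.0004, weight_decay=0.1)
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
model.train()  # ready to continue training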

Let's see a step-by-step explanation.

Functions to transform text <--> ids

These are some simple functions that can be used to transform text from the vocabulary to ids and back. This is needed at the beginning of the handling of the text and at the end of the predictions:

# Functions to transform from text to token ids and from token ids to text
def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
    encoded_tensor = torch.tensor(encoded).unsqueeze(0) # add batch dimension
    return encoded_tensor

def token_ids_to_text(token_ids, tokenizer):
    flat = token_ids.squeeze(0) # remove batch dimension
    return tokenizer.decode(flat.tolist())
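
For example, a quick round trip (reusing the gpt2 tokenizer initialized earlier):

ids = text_to_token_ids("Every effort moves you", tokenizer)
print(ids)                                # e.g. tensor([[6109, 3626, 6100,  345]])
print(token_ids_to_text(ids, tokenizer))  # Every effort moves you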

Generate text functions

In a previous section, a function was presented that just got the most probable token after getting the logits. However, this means that for each input the same output is always going to be generated, which makes it very deterministic.

The following generate_text function will apply the top-k, temperature and multinomial concepts.

  • The top-k means that we will start reducing to -inf the logits of all the tokens except the top k tokens. So, if k=3, before making a decision only the 3 most probable tokens will have a logit different from -inf.

  • The temperature means that every logit will be divided by the temperature value. A value of 0.1 will boost the highest probability relative to the lowest one, while a temperature of 5, for example, will make the distribution flatter. This helps to introduce the variation in responses we would like the LLM to have.

  • After applying the temperature, a softmax function is applied again to make all the remaining tokens have a total probability of 1.

  • Finally, instead of choosing the token with the biggest probability, the multinomial function is applied to sample the next token according to the final probabilities. So if token 1 had a 70% probability, token 2 a 20% and token 3 a 10%, then 70% of the time token 1 will be selected, 20% of the time it will be token 2 and 10% of the time it will be token 3.

# Generate text function
def generate_text(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):

    # For-loop is the same as before: Get logits, and only focus on last time step
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]
        with torch.no_grad():
            logits = model(idx_cond)
        logits = logits[:, -1, :]

        # New: Filter logits with top_k sampling
        if top_k is not None:
            # Keep only top_k values
            top_logits, _ = torch.topk(logits, top_k)
            min_val = top_logits[:, -1]
            logits = torch.where(logits < min_val, torch.tensor(float("-inf")).to(logits.device), logits)

        # New: Apply temperature scaling
        if temperature > 0.0:
            logits = logits / temperature

            # Apply softmax to get probabilities
            probs = torch.softmax(logits, dim=-1)  # (batch_size, context_len)

            # Sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)

        # Otherwise same as before: get idx of the vocab entry with the highest logits value
        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)

        if idx_next == eos_id:  # Stop generating early if end-of-sequence token is encountered and eos_id is specified
            break

        # Same as before: append sampled index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)

    return idx

An alternative to top-k is top-p, also known as nucleus sampling: instead of keeping a fixed number of k candidates, the vocabulary is sorted by probability and the probabilities are accumulated from the highest down until a threshold is reached. Then, only those words of the vocabulary will be considered according to their relative probabilities.

This removes the need to select a number k of samples, as the optimal k might be different in each case; only a threshold is needed.

Note that this improvement isn't included in the previous code.
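As a minimal sketch (assuming logits shaped (batch, vocab_size) and a hypothetical top_p threshold), the filtering could look like this and be plugged into generate_text in place of the top-k block:

def top_p_filter(logits, top_p=0.7):
    # Sort logits from most to least probable
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    # Cumulative probability of the sorted candidates
    cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    # Mask every token after the cumulative probability crosses the threshold
    # (shifted one position so the token that crosses it is still kept)
    mask = cum_probs > top_p
    mask[..., 1:] = mask[..., :-1].clone()
    mask[..., 0] = False
    sorted_logits[mask] = float("-inf")
    # Undo the sorting so the logits are back in vocabulary order
    return torch.full_like(logits, float("-inf")).scatter(-1, sorted_idx, sorted_logits)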

Another way to improve the generated text is by using Beam search instead of the greedy search used in this example. Unlike greedy search, which selects the most probable next word at each step and builds a single sequence, beam search keeps track of the top k highest-scoring partial sequences (called "beams") at each step. By exploring multiple possibilities simultaneously, it balances efficiency and quality, increasing the chances of finding a better overall sequence that might be missed by the greedy approach due to early, suboptimal choices.

Note that this improvement isn't included in the previous code.
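For reference, a simplified sketch of beam search on top of the model above (beam_width is an assumed parameter; real implementations also length-normalize the scores and handle end-of-sequence tokens):

def beam_search(model, idx, max_new_tokens, context_size, beam_width=3):
    # Each beam is a (sequence, cumulative log-probability) pair
    beams = [(idx, 0.0)]
    for _ in range(max_new_tokens):
        candidates = []
        for seq, score in beams:
            with torch.no_grad():
                logits = model(seq[:, -context_size:])
            log_probs = torch.log_softmax(logits[:, -1, :], dim=-1)
            # Expand each beam with its beam_width most probable next tokens
            top_lp, top_ids = torch.topk(log_probs, beam_width)
            for lp, tok in zip(top_lp[0], top_ids[0]):
                new_seq = torch.cat((seq, tok.view(1, 1)), dim=1)
                candidates.append((new_seq, score + lp.item()))
        # Keep only the beam_width best-scoring partial sequences
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_width]
    return beams[0][0]  # highest-scoring sequence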

Loss functions

The calc_loss_batch function calculates the cross entropy of the prediction of a single batch. The calc_loss_loader function gets the cross entropy of all the batches and calculates the average cross entropy.

# Define loss functions
def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    logits = model(input_batch)
    loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())
    return loss

def calc_loss_loader(data_loader, model, device, num_batches=None):
    total_loss = 0.
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        # Reduce the number of batches to match the total number of batches in the data loader
        # if num_batches exceeds the number of batches in the data loader
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches

Gradient clipping is a technique used to enhance training stability in large neural networks by setting a maximum threshold for gradient magnitudes. When gradients exceed this predefined max_norm, they are scaled down proportionally to ensure that updates to the model’s parameters remain within a manageable range, preventing issues like exploding gradients and ensuring more controlled and stable training.

Note that this improvement isn't included in the previous code.

Check the following example:
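A minimal sketch, applied between loss.backward() and optimizer.step() inside the training loop (max_norm=1.0 is an assumed threshold):

loss = calc_loss_batch(input_batch, target_batch, model, device)
loss.backward()
# Scale the gradients down if their global norm exceeds max_norm
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()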

Loading Data

The functions GPTDatasetV1 and create_dataloader_v1 were already discussed in a previous section.

From here, note how it's defined that 90% of the text is going to be used for training while 10% will be used for validation, and both sets are stored in 2 different data loaders. Note that sometimes part of the data set is also left as a testing set to better evaluate the performance of the model.

Both data loaders use the same batch size, maximum length, stride and number of workers (0 in this case). The main differences are the data used by each, and that the validation loader neither drops the last batch nor shuffles the data, as that's not needed for validation purposes.

Also, the fact that the stride is as big as the context length means that there won't be overlapping between the contexts used to train the data (this reduces overfitting, but also the size of the training data set).

Moreover, note that the batch size in this case is 2, dividing the data into batches of 2; the main goal of this is to allow parallel processing and reduce the memory consumption per batch.

train_ratio = 0.90
split_idx = int(train_ratio * len(text_data))
train_data = text_data[:split_idx]
val_data = text_data[split_idx:]

torch.manual_seed(123)

train_loader = create_dataloader_v1(
    train_data,
    batch_size=2,
    max_length=GPT_CONFIG_124M["context_length"],
    stride=GPT_CONFIG_124M["context_length"],
    drop_last=True,
    shuffle=True,
    num_workers=0
)

val_loader = create_dataloader_v1(
    val_data,
    batch_size=2,
    max_length=GPT_CONFIG_124M["context_length"],
    stride=GPT_CONFIG_124M["context_length"],
    drop_last=False,
    shuffle=False,
    num_workers=0
)

Sanity Checks

The goal is to check that there are enough tokens for training, that the shapes are the expected ones, and to get some info about the number of tokens used for training and for validation:

# Sanity checks
if total_tokens * (train_ratio) < GPT_CONFIG_124M["context_length"]:
    print("Not enough tokens for the training loader. "
          "Try to lower the `GPT_CONFIG_124M['context_length']` or "
          "increase the `training_ratio`")

if total_tokens * (1-train_ratio) < GPT_CONFIG_124M["context_length"]:
    print("Not enough tokens for the validation loader. "
          "Try to lower the `GPT_CONFIG_124M['context_length']` or "
          "decrease the `training_ratio`")

print("Train loader:")
for x, y in train_loader:
    print(x.shape, y.shape)

print("\nValidation loader:")
for x, y in val_loader:
    print(x.shape, y.shape)

train_tokens = 0
for input_batch, target_batch in train_loader:
    train_tokens += input_batch.numel()

val_tokens = 0
for input_batch, target_batch in val_loader:
    val_tokens += input_batch.numel()

print("Training tokens:", train_tokens)
print("Validation tokens:", val_tokens)
print("All tokens:", train_tokens + val_tokens)

Select device for training & pre calculations

The following code just selects the device to use and calculates the training loss and validation loss (without having trained anything yet) as a starting point.

# Indicate the device to use

if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using {device} device.")

model.to(device) # no assignment `model = model.to(device)` is necessary for nn.Module classes

# Pre-calculate losses without starting yet
torch.manual_seed(123) # For reproducibility due to the shuffling in the data loader

with torch.no_grad(): # Disable gradient tracking for efficiency because we are not training, yet
    train_loss = calc_loss_loader(train_loader, model, device)
    val_loss = calc_loss_loader(val_loader, model, device)

print("Training loss:", train_loss)
print("Validation loss:", val_loss)

Training functions

The function generate_and_print_sample just takes a context and generates some tokens in order to get a feel for how good the model is at that point. It's called by train_model_simple after each epoch.

The function evaluate_model is called as frequently as indicated to the training function, and it's used to measure the training loss and the validation loss at that point of the training.

Then, the big function train_model_simple is the one that actually trains the model. It expects:

  • The train data loader (with the data already separated and prepared for training)

  • The validation loader

  • The optimizer to use during training: This is the function that will use the gradients and will update the parameters to reduce the loss. In this case, as you will see, AdamW is used, but there are many more.

    • optimizer.zero_grad() is called to reset the gradients on each iteration so they don't accumulate.

    • The lr param is the learning rate which determines the size of the steps taken during the optimization process when updating the model's parameters. A smaller learning rate means the optimizer makes smaller updates to the weights, which can lead to more precise convergence but might slow down training. A larger learning rate can speed up training but risks overshooting the minimum of the loss function (jump over the point where the loss function is minimized).

    • Weight decay modifies the loss calculation step by adding an extra term that penalizes large weights. This encourages the optimizer to find solutions with smaller weights, balancing fitting the data well with keeping the model simple, which prevents overfitting by discouraging the model from assigning too much importance to any single feature.

      • Traditional optimizers like SGD with L2 regularization couple weight decay with the gradient of the loss function. However, AdamW (a variant of the Adam optimizer) decouples weight decay from the gradient update, leading to more effective regularization (see the sketch after this list).

  • The device to use for training

  • The number of epochs: Number of times to go over the training data

  • The evaluation frequency: The frequency to call evaluate_model

  • The evaluation iterations: The number of batches to use when evaluating the current state of the model when calling evaluate_model

  • The start context: The starting sentence to use when calling generate_and_print_sample

  • The tokenizer
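
As referenced in the weight decay bullet above, here is a tiny sketch (toy numbers, not the book's code) of why decoupling matters: with Adam, the gradient step is divided by a second-moment estimate, so an L2 penalty coupled into the gradient gets rescaled too, while AdamW applies the decay unscaled:

import torch

lr, wd = 0.1, 0.01
w = torch.tensor([1.0])          # current weight
grad = torch.tensor([0.5])       # gradient of the loss w.r.t. w
denom = torch.tensor([4.0]).sqrt() + 1e-8  # pretend Adam second-moment term

w_l2 = w - lr * (grad + wd * w) / denom        # Adam + L2: the decay is rescaled by denom
w_adamw = w - lr * grad / denom - lr * wd * w  # AdamW: the decay is applied directly to w

print(w_l2.item(), w_adamw.item())  # the effective decay strength differs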

# Functions to train the data
def train_model_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
                       eval_freq, eval_iter, start_context, tokenizer):
    # Initialize lists to track losses and tokens seen
    train_losses, val_losses, track_tokens_seen = [], [], []
    tokens_seen, global_step = 0, -1

    # Main training loop
    for epoch in range(num_epochs):
        model.train()  # Set model to training mode
        
        for input_batch, target_batch in train_loader:
            optimizer.zero_grad() # Reset loss gradients from previous batch iteration
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            loss.backward() # Calculate loss gradients
            optimizer.step() # Update model weights using loss gradients
            tokens_seen += input_batch.numel()
            global_step += 1

            # Optional evaluation step
            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_loader, val_loader, device, eval_iter)
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                track_tokens_seen.append(tokens_seen)
                print(f"Ep {epoch+1} (Step {global_step:06d}): "
                      f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")

        # Print a sample text after each epoch
        generate_and_print_sample(
            model, tokenizer, device, start_context
        )

    return train_losses, val_losses, track_tokens_seen


def evaluate_model(model, train_loader, val_loader, device, eval_iter):
    model.eval() # Set in eval mode to avoid dropout
    with torch.no_grad():
        train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
    model.train() # Back to training mode, re-enabling all the training configurations
    return train_loss, val_loss


def generate_and_print_sample(model, tokenizer, device, start_context):
    model.eval() # Set in eval mode to avoid dropout
    context_size = model.pos_emb.weight.shape[0]
    encoded = text_to_token_ids(start_context, tokenizer).to(device)
    with torch.no_grad():
        token_ids = generate_text(
            model=model, idx=encoded,
            max_new_tokens=50, context_size=context_size
        )
    decoded_text = token_ids_to_text(token_ids, tokenizer)
    print(decoded_text.replace("\n", " "))  # Compact print format
    model.train() # Back to training mode, re-enabling all the training configurations

To improve the learning rate schedule there are a couple of relevant techniques called linear warmup and cosine decay.

Linear warmup consists of defining an initial learning rate and a maximum one, and consistently increasing the rate on each update step until the maximum is reached. Starting the training with smaller weight updates decreases the risk of the model encountering large, destabilizing updates during its training phase. Cosine decay is a technique that then gradually reduces the learning rate following a half-cosine curve after the warmup phase, slowing weight updates to minimize the risk of overshooting the loss minima and to ensure training stability in the later phases.

Note that these improvements aren't included in the previous code.
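
A minimal sketch of such a scheduler (the step counts and rates below are assumed values, and get_lr is a hypothetical helper, not the book's exact code):

import math

# Linear warmup from initial_lr to peak_lr, then cosine decay down to min_lr
def get_lr(step, total_steps, warmup_steps=20, initial_lr=1e-5, peak_lr=4e-4, min_lr=1e-6):
    if step < warmup_steps:
        # Warmup: grow linearly towards the peak learning rate
        return initial_lr + (peak_lr - initial_lr) * step / warmup_steps
    # Decay: follow half a cosine from the peak down to the minimum
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + (peak_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))

# Inside the training loop, the rate would be set before optimizer.step():
# for param_group in optimizer.param_groups:
#     param_group["lr"] = get_lr(global_step, total_steps)
print([round(get_lr(s, 100), 6) for s in (0, 10, 20, 60, 100)])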

Start training

import time
start_time = time.time()

torch.manual_seed(123)
model = GPTModel(GPT_CONFIG_124M)
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0004, weight_decay=0.1)

num_epochs = 10
train_losses, val_losses, tokens_seen = train_model_simple(
    model, train_loader, val_loader, optimizer, device,
    num_epochs=num_epochs, eval_freq=5, eval_iter=5,
    start_context="Every effort moves you", tokenizer=tokenizer
)

end_time = time.time()
execution_time_minutes = (end_time - start_time) / 60
print(f"Training completed in {execution_time_minutes:.2f} minutes.")

Print training evolution

With the following function it's possible to plot the evolution of the model while it was being trained.

import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import math
def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses):
    fig, ax1 = plt.subplots(figsize=(5, 3))
    ax1.plot(epochs_seen, train_losses, label="Training loss")
    ax1.plot(
        epochs_seen, val_losses, linestyle="-.", label="Validation loss"
    )
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel("Loss")
    ax1.legend(loc="upper right")
    ax1.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax2 = ax1.twiny()
    ax2.plot(tokens_seen, train_losses, alpha=0)
    ax2.set_xlabel("Tokens seen")
    fig.tight_layout()
    plt.show()

    # Compute perplexity from the loss values
    train_ppls = [math.exp(loss) for loss in train_losses]
    val_ppls = [math.exp(loss) for loss in val_losses]
    # Plot perplexity over tokens seen
    plt.figure()
    plt.plot(tokens_seen, train_ppls, label='Training Perplexity')
    plt.plot(tokens_seen, val_ppls, label='Validation Perplexity')
    plt.xlabel('Tokens Seen')
    plt.ylabel('Perplexity')
    plt.title('Perplexity over Training')
    plt.legend()
    plt.show()

epochs_tensor = torch.linspace(0, num_epochs, len(train_losses))
plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses)

Save the model

It's possible to save the model + optimizer if you want to continue training later:

# Save the model and the optimizer for later training
torch.save({
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    },
"/tmp/model_and_optimizer.pth"
)
# Note that this model with the optimizer occupied close to 2GB

# Restore model and optimizer for training
checkpoint = torch.load("/tmp/model_and_optimizer.pth", map_location=device)

model = GPTModel(GPT_CONFIG_124M)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.1)
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
model.train() # Put in training mode

Or just the model if you are only planning on using it:

# Save the model
torch.save(model.state_dict(), "model.pth")

# Load it
model = GPTModel(GPT_CONFIG_124M)

model.load_state_dict(torch.load("model.pth", map_location=device))

model.eval() # Put in eval mode

This is the initial code proposed in https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/01_main-chapter-code/ch05.ipynb, sometimes slightly modified.

There is a common alternative to top-k called top-p, also known as nucleus sampling, which instead of taking the k samples with the highest probability, sorts the whole resulting vocabulary by probability and sums the probabilities from the highest to the lowest until a threshold is reached.
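
A minimal sketch of nucleus sampling (the helper name sample_top_p is hypothetical, and the logits are toy values):

import torch

# `logits` is assumed to be the model's output for the last position (shape: [vocab_size])
def sample_top_p(logits, p=0.9):
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Keep the smallest set of tokens whose cumulative probability reaches p
    cutoff = int(torch.searchsorted(cumulative, p).item()) + 1
    keep_probs = sorted_probs[:cutoff] / sorted_probs[:cutoff].sum()  # renormalize
    choice = torch.multinomial(keep_probs, num_samples=1)
    return sorted_idx[choice]

torch.manual_seed(123)
next_token = sample_top_p(torch.randn(50257))  # toy logits over a GPT2-sized vocab
print(next_token)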

Loading GPT2 weights

There are 2 quick scripts to load the GPT2 weights locally. For both you can clone the repository https://github.com/rasbt/LLMs-from-scratch locally, then:

  • The script https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/01_main-chapter-code/gpt_generate.py will download all the weights and transform the formats from OpenAI into the ones expected by our LLM. The script also comes prepared with the needed configuration and with the prompt: "Every effort moves you"

  • The script https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/02_alternative_weight_loading/weight-loading-hf-transformers.ipynb allows you to load any of the GPT2 weights locally (just change the CHOOSE_MODEL var) and predict text from some prompts.

References

  • https://www.manning.com/books/build-a-large-language-model-from-scratch