18 Commits

Author SHA1 Message Date
96c21fffa9 separate depth from inputs 2026-03-28 08:44:22 -07:00
357e0f4a20 print best loss 2026-03-28 08:36:16 -07:00
31bfa208f8 include angles in history 2026-03-28 08:26:03 -07:00
1d09378bfd silence lint 2026-03-28 08:08:01 -07:00
bf2bf6d693 dropout first 2026-03-28 07:37:04 -07:00
a144ff1178 fix file name shenanigans 2026-03-27 19:33:19 -07:00
48f9657d0f don't hardcode map and bot 2026-03-27 16:46:02 -07:00
e38c0a92b4 remove chrono dep 2026-03-27 16:28:30 -07:00
148471dce1 simulate: add output_file argument 2026-03-27 16:28:30 -07:00
7bf439395b write model name based on num params 2026-03-27 16:16:26 -07:00
03f5eb5c13 tweak model 2026-03-27 16:04:03 -07:00
1e7bb6c4ce format 2026-03-27 15:57:32 -07:00
d8b0f9abbb rename GraphicsState to InputGenerator 2026-03-27 15:57:17 -07:00
fb8c6e2492 training options 2026-03-27 15:56:15 -07:00
18cad85b62 add cli args 2026-03-27 15:56:15 -07:00
b195a7eb95 split code into modules 2026-03-27 15:46:51 -07:00
4208090da0 add clap dep 2026-03-27 15:32:17 -07:00
e31b148f41 simulator 2026-03-27 15:29:32 -07:00
7 changed files with 912 additions and 479 deletions

231
Cargo.lock generated
View File

@@ -82,6 +82,56 @@ dependencies = [
"libc",
]
[[package]]
name = "anstream"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "824a212faf96e9acacdbd09febd34438f8f711fb84e09a8916013cd7815ca28d"
dependencies = [
"anstyle",
"anstyle-parse",
"anstyle-query",
"anstyle-wincon",
"colorchoice",
"is_terminal_polyfill",
"utf8parse",
]
[[package]]
name = "anstyle"
version = "1.0.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "940b3a0ca603d1eade50a4846a2afffd5ef57a9feac2c0e2ec2e14f9ead76000"
[[package]]
name = "anstyle-parse"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "52ce7f38b242319f7cabaa6813055467063ecdc9d355bbb4ce0c68908cd8130e"
dependencies = [
"utf8parse",
]
[[package]]
name = "anstyle-query"
version = "1.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc"
dependencies = [
"windows-sys 0.61.2",
]
[[package]]
name = "anstyle-wincon"
version = "3.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d"
dependencies = [
"anstyle",
"once_cell_polyfill",
"windows-sys 0.61.2",
]
[[package]]
name = "anyhow"
version = "1.0.102"
@@ -1038,9 +1088,9 @@ dependencies = [
[[package]]
name = "cc"
version = "1.2.57"
version = "1.2.58"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a0dd1ca384932ff3641c8718a02769f1698e7563dc6974ffd03346116310423"
checksum = "e1e928d4b69e3077709075a938a05ffbedfa53a84c8f766efbf8220bb1ff60e1"
dependencies = [
"find-msvc-tools",
"jobserver",
@@ -1080,15 +1130,6 @@ dependencies = [
"rand_core 0.10.0",
]
[[package]]
name = "chrono"
version = "0.4.44"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0"
dependencies = [
"num-traits",
]
[[package]]
name = "cipher"
version = "0.4.4"
@@ -1110,6 +1151,46 @@ dependencies = [
"libloading",
]
[[package]]
name = "clap"
version = "4.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b193af5b67834b676abd72466a96c1024e6a6ad978a1f484bd90b85c94041351"
dependencies = [
"clap_builder",
"clap_derive",
]
[[package]]
name = "clap_builder"
version = "4.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "714a53001bf66416adb0e2ef5ac857140e7dc3a0c48fb28b2f10762fc4b5069f"
dependencies = [
"anstream",
"anstyle",
"clap_lex",
"strsim",
]
[[package]]
name = "clap_derive"
version = "4.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1110bd8a634a1ab8cb04345d8d878267d57c3cf1b38d91b71af6686408bbca6a"
dependencies = [
"heck",
"proc-macro2",
"quote",
"syn 2.0.117",
]
[[package]]
name = "clap_lex"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8d4a3bb8b1e0c1050499d1815f5ab16d04f0959b233085fb31653fbfc9d98f9"
[[package]]
name = "codespan-reporting"
version = "0.12.0"
@@ -1118,7 +1199,7 @@ checksum = "fe6d2e5af09e8c8ad56c969f2157a3d4238cebc7c55f0a517728c38f7b200f81"
dependencies = [
"serde",
"termcolor",
"unicode-width 0.1.14",
"unicode-width 0.2.0",
]
[[package]]
@@ -1129,7 +1210,7 @@ checksum = "af491d569909a7e4dee0ad7db7f5341fef5c614d5b8ec8cf765732aba3cff681"
dependencies = [
"serde",
"termcolor",
"unicode-width 0.1.14",
"unicode-width 0.2.0",
]
[[package]]
@@ -1138,13 +1219,19 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b"
[[package]]
name = "colorchoice"
version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d07550c9036bf2ae0c684c4297d503f838287c83c53686d05370d0e139ae570"
[[package]]
name = "colored"
version = "3.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "faf9468729b8cbcea668e36183cb69d317348c2e08e994829fb56ebfdfbaac34"
dependencies = [
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -2037,7 +2124,7 @@ dependencies = [
"libc",
"option-ext",
"redox_users",
"windows-sys 0.59.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -2241,7 +2328,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
dependencies = [
"libc",
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -2928,7 +3015,7 @@ dependencies = [
"log",
"presser",
"thiserror 2.0.18",
"windows 0.58.0",
"windows 0.62.2",
]
[[package]]
@@ -3405,6 +3492,12 @@ dependencies = [
"serde",
]
[[package]]
name = "is_terminal_polyfill"
version = "1.70.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695"
[[package]]
name = "itertools"
version = "0.13.0"
@@ -3798,9 +3891,9 @@ dependencies = [
[[package]]
name = "mio"
version = "1.1.1"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc"
checksum = "50b7e5b27aa02a74bac8c3f23f448f8d87ff11f92d3aac1a6ed369ee08cc56c1"
dependencies = [
"libc",
"log",
@@ -3852,9 +3945,9 @@ dependencies = [
[[package]]
name = "naga"
version = "29.0.0"
version = "29.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85b4372fed0bd362d646d01b6926df0e837859ccc522fed720c395e0460f29c8"
checksum = "aa2630921705b9b01dcdd0b6864b9562ca3c1951eecd0f0c4f5f04f61e412647"
dependencies = [
"arrayvec",
"bit-set 0.9.1",
@@ -3977,7 +4070,7 @@ version = "0.50.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5"
dependencies = [
"windows-sys 0.59.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -4210,6 +4303,12 @@ version = "1.21.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50"
[[package]]
name = "once_cell_polyfill"
version = "1.70.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe"
[[package]]
name = "option-ext"
version = "0.2.0"
@@ -5063,7 +5162,7 @@ dependencies = [
"errno",
"libc",
"linux-raw-sys 0.12.1",
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -5355,9 +5454,9 @@ dependencies = [
[[package]]
name = "simd-adler32"
version = "0.3.8"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2"
checksum = "703d5c7ef118737c72f1af64ad2f6f8c5e1921f818cdcb97b8fe6fc69bf66214"
[[package]]
name = "simd_helpers"
@@ -5406,7 +5505,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e"
dependencies = [
"libc",
"windows-sys 0.60.2",
"windows-sys 0.61.2",
]
[[package]]
@@ -5460,7 +5559,7 @@ name = "strafe-ai"
version = "0.1.0"
dependencies = [
"burn",
"chrono",
"clap",
"glam",
"pollster",
"strafesnet_common",
@@ -5469,7 +5568,7 @@ dependencies = [
"strafesnet_roblox_bot_file",
"strafesnet_roblox_bot_player",
"strafesnet_snf",
"wgpu 29.0.0",
"wgpu 29.0.1",
]
[[package]]
@@ -5498,7 +5597,7 @@ dependencies = [
"glam",
"id",
"strafesnet_common",
"wgpu 29.0.0",
"wgpu 29.0.1",
]
[[package]]
@@ -5535,7 +5634,7 @@ dependencies = [
"strafesnet_graphics",
"strafesnet_roblox_bot_file",
"thiserror 2.0.18",
"wgpu 29.0.0",
"wgpu 29.0.1",
]
[[package]]
@@ -5740,7 +5839,7 @@ dependencies = [
"getrandom 0.4.2",
"once_cell",
"rustix 1.1.4",
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -6312,9 +6411,9 @@ dependencies = [
[[package]]
name = "unicode-segmentation"
version = "1.13.1"
version = "1.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da36089a805484bcccfffe0739803392c8298778a2d2f09febf76fac5ad9025b"
checksum = "9629274872b2bfaf8d66f5f15725007f635594914870f65218920345aa11aa8c"
[[package]]
name = "unicode-truncate"
@@ -6412,10 +6511,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be"
[[package]]
name = "uuid"
version = "1.22.0"
name = "utf8parse"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a68d3c8f01c0cfa54a75291d83601161799e4a89a39e0929f4b0354d88757a37"
checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
[[package]]
name = "uuid"
version = "1.23.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ac8b6f42ead25368cf5b098aeb3dc8a1a2c05a3eee8a9a1a68c640edbfc79d9"
dependencies = [
"getrandom 0.4.2",
"js-sys",
@@ -6692,9 +6797,9 @@ dependencies = [
[[package]]
name = "wgpu"
version = "29.0.0"
version = "29.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78f9f386699b1fb8b8a05bfe82169b24d151f05702d2905a0bf93bc454fcc825"
checksum = "72c239a9a747bbd379590985bac952c2e53cb19873f7072b3370c6a6a8e06837"
dependencies = [
"arrayvec",
"bitflags",
@@ -6705,7 +6810,7 @@ dependencies = [
"hashbrown 0.16.1",
"js-sys",
"log",
"naga 29.0.0",
"naga 29.0.1",
"parking_lot",
"portable-atomic",
"profiling",
@@ -6715,9 +6820,9 @@ dependencies = [
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
"wgpu-core 29.0.0",
"wgpu-hal 29.0.0",
"wgpu-types 29.0.0",
"wgpu-core 29.0.1",
"wgpu-hal 29.0.1",
"wgpu-types 29.0.1",
]
[[package]]
@@ -6753,9 +6858,9 @@ dependencies = [
[[package]]
name = "wgpu-core"
version = "29.0.0"
version = "29.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7c34181b0acb8f98168f78f8e57ec66f57df5522b39143dbe5f2f45d7ca927c"
checksum = "1e80ac6cf1895df6342f87d975162108f9d98772a0d74bc404ab7304ac29469e"
dependencies = [
"arrayvec",
"bit-set 0.9.1",
@@ -6767,7 +6872,7 @@ dependencies = [
"hashbrown 0.16.1",
"indexmap",
"log",
"naga 29.0.0",
"naga 29.0.1",
"once_cell",
"parking_lot",
"portable-atomic",
@@ -6779,9 +6884,9 @@ dependencies = [
"wgpu-core-deps-apple 29.0.0",
"wgpu-core-deps-emscripten 29.0.0",
"wgpu-core-deps-windows-linux-android 29.0.0",
"wgpu-hal 29.0.0",
"wgpu-hal 29.0.1",
"wgpu-naga-bridge",
"wgpu-types 29.0.0",
"wgpu-types 29.0.1",
]
[[package]]
@@ -6799,7 +6904,7 @@ version = "29.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43acd053312501689cd92a01a9638d37f3e41a5fd9534875efa8917ee2d11ac0"
dependencies = [
"wgpu-hal 29.0.0",
"wgpu-hal 29.0.1",
]
[[package]]
@@ -6817,7 +6922,7 @@ version = "29.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef043bf135cc68b6f667c55ff4e345ce2b5924d75bad36a47921b0287ca4b24a"
dependencies = [
"wgpu-hal 29.0.0",
"wgpu-hal 29.0.1",
]
[[package]]
@@ -6835,7 +6940,7 @@ version = "29.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "725d5c006a8c02967b6d93ef04f6537ec4593313e330cfe86d9d3f946eb90f28"
dependencies = [
"wgpu-hal 29.0.0",
"wgpu-hal 29.0.1",
]
[[package]]
@@ -6888,9 +6993,9 @@ dependencies = [
[[package]]
name = "wgpu-hal"
version = "29.0.0"
version = "29.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "058b6047337cf323a4f092486443a9337f3d81325347e5d77deed7e563aeaedc"
checksum = "89a47aef47636562f3937285af4c44b4b5b404b46577471411cc5313a921da7e"
dependencies = [
"android_system_properties",
"arrayvec",
@@ -6911,7 +7016,7 @@ dependencies = [
"libc",
"libloading",
"log",
"naga 29.0.0",
"naga 29.0.1",
"ndk-sys",
"objc2",
"objc2-core-foundation",
@@ -6934,19 +7039,19 @@ dependencies = [
"wayland-sys",
"web-sys",
"wgpu-naga-bridge",
"wgpu-types 29.0.0",
"wgpu-types 29.0.1",
"windows 0.62.2",
"windows-core 0.62.2",
]
[[package]]
name = "wgpu-naga-bridge"
version = "29.0.0"
version = "29.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d0b8e1e505095f24cb4a578f04b1421d456257dca7fac114d9d9dd3d978c34b8"
checksum = "7b4684f4410da0cf95a4cb63bb5edaac022461dedb6adf0b64d0d9b5f6890d51"
dependencies = [
"naga 29.0.0",
"wgpu-types 29.0.0",
"naga 29.0.1",
"wgpu-types 29.0.1",
]
[[package]]
@@ -6965,9 +7070,9 @@ dependencies = [
[[package]]
name = "wgpu-types"
version = "29.0.0"
version = "29.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d15ece45db77dd5451f11c0ce898334317ce8502d304a20454b531fdc0652fae"
checksum = "ec2675540fb1a5cfa5ef122d3d5f390e2c75711a0b946410f2d6ac3a0f77d1f6"
dependencies = [
"bitflags",
"bytemuck",
@@ -6999,7 +7104,7 @@ version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
dependencies = [
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -7766,9 +7871,9 @@ dependencies = [
[[package]]
name = "zune-jpeg"
version = "0.5.14"
version = "0.5.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b7a1c0af6e5d8d1363f4994b7a091ccf963d8b694f7da5b0b9cceb82da2c0a6"
checksum = "27bc9d5b815bc103f142aa054f561d9187d191692ec7c2d1e2b4737f8dbd7296"
dependencies = [
"zune-core",
]

View File

@@ -5,6 +5,7 @@ edition = "2024"
[dependencies]
burn = { version = "0.20.1", features = ["cuda", "autodiff"] }
clap = { version = "4.6.0", features = ["derive"] }
glam = "0.32.1"
pollster = "0.4.0"
wgpu = "29.0.0"
@@ -15,4 +16,3 @@ strafesnet_physics = { version = "=0.0.2-surf", registry = "strafesnet" }
strafesnet_roblox_bot_file = { version = "0.9.4", registry = "strafesnet" }
strafesnet_roblox_bot_player = { version = "=0.6.2-depth2", registry = "strafesnet" }
strafesnet_snf = { version = "0.4.0", registry = "strafesnet" }
chrono = { version = "0.4.44", default-features = false, features = ["now"] }

254
src/inference.rs Normal file
View File

@@ -0,0 +1,254 @@
#[derive(clap::Subcommand)]
pub enum Commands {
Simulate(SimulateSubcommand),
}
impl Commands {
pub fn run(self) {
match self {
Commands::Simulate(subcommand) => subcommand.run(),
}
}
}
#[derive(clap::Args)]
pub struct SimulateSubcommand {
#[arg(long)]
gpu_id: Option<usize>,
#[arg(long)]
model_file: std::path::PathBuf,
#[arg(long)]
output_file: Option<std::path::PathBuf>,
#[arg(long)]
map_file: std::path::PathBuf,
}
impl SimulateSubcommand {
fn run(self) {
let output_file = self.output_file.unwrap_or_else(|| {
let mut file_name = self
.model_file
.file_stem()
.unwrap()
.to_str()
.unwrap()
.to_owned();
file_name.push_str("_replay.snfb");
let mut path = self.model_file.clone();
path.set_file_name(file_name);
path
});
inference(
self.gpu_id.unwrap_or_default(),
self.model_file,
output_file,
self.map_file,
);
}
}
use burn::prelude::*;
use crate::inputs::InputGenerator;
use crate::net::{DEPTH_SIZE, InferenceBackend, Net, POSITION_HISTORY_SIZE};
use strafesnet_common::instruction::TimedInstruction;
use strafesnet_common::mouse::MouseState;
use strafesnet_common::physics::{
Instruction as PhysicsInputInstruction, ModeInstruction, MouseInstruction,
SetControlInstruction, Time as PhysicsTime,
};
use strafesnet_physics::physics::{PhysicsContext, PhysicsData, PhysicsState};
pub struct Recording {
instructions: Vec<TimedInstruction<PhysicsInputInstruction, PhysicsTime>>,
}
struct FrameState {
trajectory: strafesnet_physics::physics::Trajectory,
camera: strafesnet_physics::physics::PhysicsCamera,
}
impl FrameState {
fn pos(&self, time: PhysicsTime) -> glam::Vec3 {
self.trajectory
.extrapolated_position(time)
.map(Into::<f32>::into)
.to_array()
.into()
}
fn angles(&self) -> glam::Vec2 {
self.camera.simulate_move_angles(glam::IVec2::ZERO)
}
}
struct Session {
geometry_shared: PhysicsData,
simulation: PhysicsState,
recording: Recording,
}
impl Session {
fn get_frame_state(&self) -> FrameState {
FrameState {
trajectory: self.simulation.camera_trajectory(&self.geometry_shared),
camera: self.simulation.camera(),
}
}
fn run(&mut self, time: PhysicsTime, instruction: PhysicsInputInstruction) {
let instruction = TimedInstruction { time, instruction };
self.recording.instructions.push(instruction.clone());
PhysicsContext::run_input_instruction(
&mut self.simulation,
&self.geometry_shared,
instruction,
);
}
}
fn inference(
gpu_id: usize,
model_file: std::path::PathBuf,
output_file: std::path::PathBuf,
map_file: std::path::PathBuf,
) {
// pick device
let device = burn::backend::cuda::CudaDevice::new(gpu_id);
// load model
let mut model: Net<InferenceBackend> = Net::init(&device);
model = model
.load_file(
model_file,
&burn::record::BinFileRecorder::<burn::record::FullPrecisionSettings>::new(),
&device,
)
.unwrap();
// load map
let map_file = std::fs::read(map_file).unwrap();
let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file))
.unwrap()
.into_complete_map()
.unwrap();
let modes = map.modes.clone().denormalize();
let mode = modes
.get_mode(strafesnet_common::gameplay_modes::ModeId::MAIN)
.unwrap();
let start_zone = map.models.get(mode.get_start().get() as usize).unwrap();
let start_offset = glam::Vec3::from_array(
start_zone
.transform
.translation
.map(|f| f.into())
.to_array(),
);
// setup graphics
let mut g = InputGenerator::new(&map);
// setup simulation
let mut session = Session {
geometry_shared: PhysicsData::new(&map),
simulation: PhysicsState::default(),
recording: Recording {
instructions: Vec::new(),
},
};
let mut time = PhysicsTime::ZERO;
// reset to start zone
session.run(time, PhysicsInputInstruction::Mode(ModeInstruction::Reset));
// session.run(
// time,
// PhysicsInputInstruction::Misc(MiscInstruction::SetSensitivity(?)),
// );
session.run(
time,
PhysicsInputInstruction::Mode(ModeInstruction::Restart(
strafesnet_common::gameplay_modes::ModeId::MAIN,
)),
);
// TEMP: turn mouse left
let mut mouse_pos = glam::ivec2(-5300, 0);
const STEP: PhysicsTime = PhysicsTime::from_millis(10);
let mut input_floats = Vec::new();
let mut depth_floats = Vec::new();
// setup agent-simulation feedback loop
for _ in 0..20 * 100 {
// generate inputs
let frame_state = session.get_frame_state();
g.generate_inputs(
frame_state.pos(time) - start_offset,
frame_state.angles(),
&mut input_floats,
&mut depth_floats,
);
// inference
let inputs = Tensor::from_data(
TensorData::new(input_floats.clone(), Shape::new([1, POSITION_HISTORY_SIZE])),
&device,
);
let depth = Tensor::from_data(
TensorData::new(depth_floats.clone(), Shape::new([1, DEPTH_SIZE])),
&device,
);
let outputs = model
.forward(inputs, depth)
.into_data()
.into_vec::<f32>()
.unwrap();
let &[
move_forward,
move_left,
move_back,
move_right,
jump,
mouse_dx,
mouse_dy,
] = outputs.as_slice()
else {
panic!()
};
macro_rules! set_control {
($control:ident,$output:expr) => {
session.run(
time,
PhysicsInputInstruction::SetControl(SetControlInstruction::$control(
0.5 < $output,
)),
);
};
}
set_control!(SetMoveForward, move_forward);
set_control!(SetMoveLeft, move_left);
set_control!(SetMoveBack, move_back);
set_control!(SetMoveRight, move_right);
set_control!(SetJump, jump);
mouse_pos += glam::vec2(mouse_dx, mouse_dy).round().as_ivec2();
let next_time = time + STEP;
session.run(
time,
PhysicsInputInstruction::Mouse(MouseInstruction::SetNextMouse(MouseState {
pos: mouse_pos,
time: next_time,
})),
);
time = next_time;
// clear
depth_floats.clear();
input_floats.clear();
}
let file = std::fs::File::create(output_file).unwrap();
strafesnet_snf::bot::write_bot(
std::io::BufWriter::new(file),
strafesnet_physics::VERSION.get(),
core::mem::take(&mut session.recording.instructions),
)
.unwrap();
}

174
src/inputs.rs Normal file
View File

@@ -0,0 +1,174 @@
const LIMITS: wgpu::Limits = wgpu::Limits::defaults();
const FORMAT: wgpu::TextureFormat = wgpu::TextureFormat::Rgba8UnormSrgb;
use strafesnet_graphics::setup;
use crate::net::{POSITION_HISTORY, SIZE};
// bytes_per_row needs to be a multiple of 256.
const STRIDE_SIZE: u32 = (SIZE.x * size_of::<f32>() as u32).next_multiple_of(256);
pub struct InputGenerator {
device: wgpu::Device,
queue: wgpu::Queue,
graphics: strafesnet_roblox_bot_player::graphics::Graphics,
graphics_texture_view: wgpu::TextureView,
output_staging_buffer: wgpu::Buffer,
texture_data: Vec<u8>,
position_history: Vec<(glam::Vec3, glam::Vec2)>,
}
impl InputGenerator {
pub fn new(map: &strafesnet_common::map::CompleteMap) -> Self {
let desc = wgpu::InstanceDescriptor::new_without_display_handle_from_env();
let instance = wgpu::Instance::new(desc);
let (device, queue) = pollster::block_on(async {
let adapter = instance
.request_adapter(&wgpu::RequestAdapterOptions {
power_preference: wgpu::PowerPreference::HighPerformance,
force_fallback_adapter: false,
compatible_surface: None,
})
.await
.unwrap();
setup::step4::request_device(&adapter, LIMITS)
.await
.unwrap()
});
let mut graphics = strafesnet_roblox_bot_player::graphics::Graphics::new(
&device, &queue, SIZE, FORMAT, LIMITS,
);
graphics.change_map(&device, &queue, map).unwrap();
let graphics_texture = device.create_texture(&wgpu::TextureDescriptor {
label: Some("RGB texture"),
format: FORMAT,
size: wgpu::Extent3d {
width: SIZE.x,
height: SIZE.y,
depth_or_array_layers: 1,
},
mip_level_count: 1,
sample_count: 1,
dimension: wgpu::TextureDimension::D2,
usage: wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
view_formats: &[],
});
let graphics_texture_view = graphics_texture.create_view(&wgpu::TextureViewDescriptor {
label: Some("RGB texture view"),
aspect: wgpu::TextureAspect::All,
usage: Some(
wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
),
..Default::default()
});
let texture_data = Vec::<u8>::with_capacity((STRIDE_SIZE * SIZE.y) as usize);
let output_staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Output staging buffer"),
size: texture_data.capacity() as u64,
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
mapped_at_creation: false,
});
let position_history = Vec::with_capacity(POSITION_HISTORY);
Self {
device,
queue,
graphics,
graphics_texture_view,
output_staging_buffer,
texture_data,
position_history,
}
}
pub fn generate_inputs(
&mut self,
pos: glam::Vec3,
angles: glam::Vec2,
inputs: &mut Vec<f32>,
depth: &mut Vec<f32>,
) {
// write position history to model inputs
if !self.position_history.is_empty() {
let camera = strafesnet_graphics::graphics::view_inv(pos, angles).inverse();
for &(pos, ang) in self.position_history.iter().rev() {
let relative_pos = camera.transform_vector3(pos);
let relative_ang = glam::vec2(angles.x - ang.x, ang.y);
inputs.extend_from_slice(&relative_pos.to_array());
inputs.extend_from_slice(&relative_ang.to_array());
}
}
// fill remaining history with zeroes
for _ in self.position_history.len()..POSITION_HISTORY {
inputs.extend_from_slice(&[0.0, 0.0, 0.0, 0.0, 0.0]);
}
// track position history
if self.position_history.len() < POSITION_HISTORY {
self.position_history.push((pos, angles));
} else {
self.position_history.rotate_left(1);
*self.position_history.last_mut().unwrap() = (pos, angles);
}
let mut encoder = self
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("wgpu encoder"),
});
// render!
self.graphics
.encode_commands(&mut encoder, &self.graphics_texture_view, pos, angles);
// copy the depth texture into ram
encoder.copy_texture_to_buffer(
wgpu::TexelCopyTextureInfo {
texture: self.graphics.depth_texture(),
mip_level: 0,
origin: wgpu::Origin3d::ZERO,
aspect: wgpu::TextureAspect::All,
},
wgpu::TexelCopyBufferInfo {
buffer: &self.output_staging_buffer,
layout: wgpu::TexelCopyBufferLayout {
offset: 0,
// This needs to be a multiple of 256.
bytes_per_row: Some(STRIDE_SIZE),
rows_per_image: Some(SIZE.y),
},
},
wgpu::Extent3d {
width: SIZE.x,
height: SIZE.y,
depth_or_array_layers: 1,
},
);
self.queue.submit([encoder.finish()]);
// map buffer
let buffer_slice = self.output_staging_buffer.slice(..);
let (sender, receiver) = std::sync::mpsc::channel();
buffer_slice.map_async(wgpu::MapMode::Read, move |r| sender.send(r).unwrap());
self.device
.poll(wgpu::PollType::wait_indefinitely())
.unwrap();
receiver.recv().unwrap().unwrap();
// copy texture inside a scope so the mapped view gets dropped
{
let view = buffer_slice.get_mapped_range();
self.texture_data.extend_from_slice(&view[..]);
}
self.output_staging_buffer.unmap();
// discombolulate stride
for y in 0..SIZE.y {
depth.extend(
self.texture_data[(STRIDE_SIZE * y) as usize
..(STRIDE_SIZE * y + SIZE.x * size_of::<f32>() as u32) as usize]
.chunks_exact(4)
.map(|b| 1.0 - 2.0 * f32::from_le_bytes(b.try_into().unwrap())),
)
}
self.texture_data.clear();
}
}

View File

@@ -1,423 +1,30 @@
use burn::backend::Autodiff;
use burn::nn::loss::{MseLoss, Reduction};
use burn::nn::{Dropout, DropoutConfig, Linear, LinearConfig, Relu};
use burn::optim::{AdamConfig, GradientsParams, Optimizer};
use burn::prelude::*;
use clap::{Parser, Subcommand};
type InferenceBackend = burn::backend::Cuda<f32>;
type TrainingBackend = Autodiff<InferenceBackend>;
mod inference;
mod inputs;
mod net;
mod training;
const LIMITS: wgpu::Limits = wgpu::Limits::defaults();
const FORMAT: wgpu::TextureFormat = wgpu::TextureFormat::Rgba8UnormSrgb;
use strafesnet_graphics::setup;
use strafesnet_roblox_bot_file::v0;
const SIZE: glam::UVec2 = glam::uvec2(64, 36);
const POSITION_HISTORY: usize = 4;
const INPUT: usize = (SIZE.x * SIZE.y) as usize + POSITION_HISTORY * 3;
const HIDDEN: [usize; 2] = [INPUT >> 3, INPUT >> 7];
// MoveForward
// MoveLeft
// MoveBack
// MoveRight
// Jump
// mouse_dx
// mouse_dy
const OUTPUT: usize = 7;
// bytes_per_row needs to be a multiple of 256.
const STRIDE_SIZE: u32 = (SIZE.x * size_of::<f32>() as u32).next_multiple_of(256);
#[derive(Module, Debug)]
struct Net<B: Backend> {
input: Linear<B>,
dropout: Dropout,
hidden: [Linear<B>; HIDDEN.len() - 1],
output: Linear<B>,
activation: Relu,
}
impl<B: Backend> Net<B> {
fn init(device: &B::Device) -> Self {
let mut it = HIDDEN.into_iter();
let mut last_size = it.next().unwrap();
let input = LinearConfig::new(INPUT, last_size).init(device);
let hidden = core::array::from_fn(|_| {
let size = it.next().unwrap();
let layer = LinearConfig::new(last_size, size).init(device);
last_size = size;
layer
});
let output = LinearConfig::new(last_size, OUTPUT).init(device);
let dropout = DropoutConfig::new(0.1).init();
Self {
input,
dropout,
hidden,
output,
activation: Relu::new(),
}
}
fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
let x = self.input.forward(input);
let x = self.dropout.forward(x);
let mut x = self.activation.forward(x);
for layer in &self.hidden {
x = layer.forward(x);
x = self.activation.forward(x);
}
self.output.forward(x)
}
#[derive(Parser)]
#[command(author,version,about,long_about=None)]
#[command(propagate_version = true)]
struct Cli {
#[command(subcommand)]
command: Commands,
}
struct GraphicsState {
device: wgpu::Device,
queue: wgpu::Queue,
graphics: strafesnet_roblox_bot_player::graphics::Graphics,
graphics_texture_view: wgpu::TextureView,
output_staging_buffer: wgpu::Buffer,
texture_data: Vec<u8>,
position_history: Vec<glam::Vec3>,
}
impl GraphicsState {
fn new(map: &strafesnet_common::map::CompleteMap) -> Self {
let desc = wgpu::InstanceDescriptor::new_without_display_handle_from_env();
let instance = wgpu::Instance::new(desc);
let (device, queue) = pollster::block_on(async {
let adapter = instance
.request_adapter(&wgpu::RequestAdapterOptions {
power_preference: wgpu::PowerPreference::HighPerformance,
force_fallback_adapter: false,
compatible_surface: None,
})
.await
.unwrap();
setup::step4::request_device(&adapter, LIMITS)
.await
.unwrap()
});
let mut graphics = strafesnet_roblox_bot_player::graphics::Graphics::new(
&device, &queue, SIZE, FORMAT, LIMITS,
);
graphics.change_map(&device, &queue, map).unwrap();
let graphics_texture = device.create_texture(&wgpu::TextureDescriptor {
label: Some("RGB texture"),
format: FORMAT,
size: wgpu::Extent3d {
width: SIZE.x,
height: SIZE.y,
depth_or_array_layers: 1,
},
mip_level_count: 1,
sample_count: 1,
dimension: wgpu::TextureDimension::D2,
usage: wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
view_formats: &[],
});
let graphics_texture_view = graphics_texture.create_view(&wgpu::TextureViewDescriptor {
label: Some("RGB texture view"),
aspect: wgpu::TextureAspect::All,
usage: Some(
wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
),
..Default::default()
});
let texture_data = Vec::<u8>::with_capacity((STRIDE_SIZE * SIZE.y) as usize);
let output_staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Output staging buffer"),
size: texture_data.capacity() as u64,
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
mapped_at_creation: false,
});
let position_history = Vec::with_capacity(POSITION_HISTORY);
Self {
device,
queue,
graphics,
graphics_texture_view,
output_staging_buffer,
texture_data,
position_history,
}
}
fn generate_inputs(&mut self, pos: glam::Vec3, angles: glam::Vec2, inputs: &mut Vec<f32>) {
// write position history to model inputs
if !self.position_history.is_empty() {
let camera = strafesnet_graphics::graphics::view_inv(pos, angles).inverse();
for &pos in self.position_history.iter().rev() {
let relative_pos = camera.transform_vector3(pos);
inputs.extend_from_slice(&relative_pos.to_array());
}
}
// fill remaining history with zeroes
for _ in self.position_history.len()..POSITION_HISTORY {
inputs.extend_from_slice(&[0.0, 0.0, 0.0]);
}
// track position history
if self.position_history.len() < POSITION_HISTORY {
self.position_history.push(pos);
} else {
self.position_history.rotate_left(1);
*self.position_history.last_mut().unwrap() = pos;
}
let mut encoder = self
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("wgpu encoder"),
});
// render!
self.graphics
.encode_commands(&mut encoder, &self.graphics_texture_view, pos, angles);
// copy the depth texture into ram
encoder.copy_texture_to_buffer(
wgpu::TexelCopyTextureInfo {
texture: self.graphics.depth_texture(),
mip_level: 0,
origin: wgpu::Origin3d::ZERO,
aspect: wgpu::TextureAspect::All,
},
wgpu::TexelCopyBufferInfo {
buffer: &self.output_staging_buffer,
layout: wgpu::TexelCopyBufferLayout {
offset: 0,
// This needs to be a multiple of 256.
bytes_per_row: Some(STRIDE_SIZE),
rows_per_image: Some(SIZE.y),
},
},
wgpu::Extent3d {
width: SIZE.x,
height: SIZE.y,
depth_or_array_layers: 1,
},
);
self.queue.submit([encoder.finish()]);
// map buffer
let buffer_slice = self.output_staging_buffer.slice(..);
let (sender, receiver) = std::sync::mpsc::channel();
buffer_slice.map_async(wgpu::MapMode::Read, move |r| sender.send(r).unwrap());
self.device
.poll(wgpu::PollType::wait_indefinitely())
.unwrap();
receiver.recv().unwrap().unwrap();
// copy texture inside a scope so the mapped view gets dropped
{
let view = buffer_slice.get_mapped_range();
self.texture_data.extend_from_slice(&view[..]);
}
self.output_staging_buffer.unmap();
// discombolulate stride
for y in 0..SIZE.y {
inputs.extend(
self.texture_data[(STRIDE_SIZE * y) as usize
..(STRIDE_SIZE * y + SIZE.x * size_of::<f32>() as u32) as usize]
.chunks_exact(4)
.map(|b| 1.0 - 2.0 * f32::from_le_bytes(b.try_into().unwrap())),
)
}
self.texture_data.clear();
}
}
fn training() {
let gpu_id: usize = std::env::args()
.skip(1)
.next()
.map(|id| id.parse().unwrap())
.unwrap_or_default();
// load map
// load replay
// setup player
let map_file = include_bytes!("../files/bhop_marble_5692093612.snfm");
let bot_file = include_bytes!("../files/bhop_marble_7cf33a64-7120-4514-b9fa-4fe29d9523d.qbot");
// read files
let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file))
.unwrap()
.into_complete_map()
.unwrap();
let timelines =
strafesnet_roblox_bot_file::v0::read_all_to_block(std::io::Cursor::new(bot_file)).unwrap();
let bot = strafesnet_roblox_bot_player::bot::CompleteBot::new(timelines).unwrap();
let world_offset = bot.world_offset();
let timelines = bot.timelines();
// setup simulation
// run progressively longer segments of the map, starting very close to the end of the run and working the starting time backwards until the ai can run the whole map
// set up graphics
let mut g = GraphicsState::new(&map);
// training data
let training_samples = timelines.input_events.len() - 1;
let input_size = INPUT * size_of::<f32>();
let mut inputs = Vec::with_capacity(input_size * training_samples);
let mut targets = Vec::with_capacity(OUTPUT * training_samples);
// generate all frames
println!("Generating {training_samples} frames of depth textures...");
let mut it = timelines.input_events.iter();
// grab mouse position from first frame, omitting one frame from the training data
let first = it.next().unwrap();
let mut last_mx = first.event.mouse_pos.x;
let mut last_my = first.event.mouse_pos.y;
for input_event in it {
let mouse_dx = input_event.event.mouse_pos.x - last_mx;
let mouse_dy = input_event.event.mouse_pos.y - last_my;
last_mx = input_event.event.mouse_pos.x;
last_my = input_event.event.mouse_pos.y;
// set targets
targets.extend([
// MoveForward
input_event
.event
.game_controls
.contains(v0::GameControls::MoveForward) as i32 as f32,
// MoveLeft
input_event
.event
.game_controls
.contains(v0::GameControls::MoveLeft) as i32 as f32,
// MoveBack
input_event
.event
.game_controls
.contains(v0::GameControls::MoveBack) as i32 as f32,
// MoveRight
input_event
.event
.game_controls
.contains(v0::GameControls::MoveRight) as i32 as f32,
// Jump
input_event
.event
.game_controls
.contains(v0::GameControls::Jump) as i32 as f32,
mouse_dx,
mouse_dy,
]);
// find the closest output event to the input event time
let output_event_index = timelines
.output_events
.binary_search_by(|event| event.time.partial_cmp(&input_event.time).unwrap());
let output_event = match output_event_index {
// found the exact same timestamp
Ok(output_event_index) => &timelines.output_events[output_event_index],
// found first index greater than the time.
// check this index and the one before and return the closest one
Err(insert_index) => timelines
.output_events
.get(insert_index)
.into_iter()
.chain(
insert_index
.checked_sub(1)
.and_then(|index| timelines.output_events.get(index)),
)
.min_by(|&e0, &e1| {
(e0.time - input_event.time)
.abs()
.partial_cmp(&(e1.time - input_event.time).abs())
.unwrap()
})
.unwrap(),
};
fn vec3(v: v0::Vector3) -> glam::Vec3 {
glam::vec3(v.x, v.y, v.z)
}
fn angles(a: v0::Vector3) -> glam::Vec2 {
glam::vec2(a.y, a.x)
}
let pos = vec3(output_event.event.position) - world_offset;
let angles = angles(output_event.event.angles);
g.generate_inputs(pos, angles, &mut inputs);
}
let device = burn::backend::cuda::CudaDevice::new(gpu_id);
let mut model: Net<TrainingBackend> = Net::init(&device);
println!("Training model ({} parameters)", model.num_params());
let mut optim = AdamConfig::new().init();
let inputs = Tensor::from_data(
TensorData::new(inputs, Shape::new([training_samples, INPUT])),
&device,
);
let targets = Tensor::from_data(
TensorData::new(targets, Shape::new([training_samples, OUTPUT])),
&device,
);
const LEARNING_RATE: f64 = 0.001;
const EPOCHS: usize = 100000;
let mut best_model = model.clone();
let mut best_loss = f32::INFINITY;
for epoch in 0..EPOCHS {
let predictions = model.forward(inputs.clone());
let loss = MseLoss::new().forward(predictions, targets.clone(), Reduction::Mean);
let loss_scalar = loss.clone().into_scalar();
if epoch == 0 {
// kinda a fake print, but that's what is happening after this point
println!("Compiling optimized GPU kernels...");
}
let grads = loss.backward();
let grads = GradientsParams::from_grads(grads, &model);
// get the best model
if loss_scalar < best_loss {
best_loss = loss_scalar;
best_model = model.clone();
}
model = optim.step(LEARNING_RATE, model, grads);
if epoch % (EPOCHS >> 4) == 0 || epoch == EPOCHS - 1 {
// .clone().into_scalar() extracts the f32 value from a 1-element tensor.
println!(" epoch {:>5} | loss = {:.8}", epoch, loss_scalar);
}
}
let date_string = format!("{}_{}.model", chrono::Utc::now(), best_loss);
best_model
.save_file(
date_string,
&burn::record::BinFileRecorder::<burn::record::FullPrecisionSettings>::new(),
)
.unwrap();
}
fn inference() {
// load map
// setup simulation
// setup agent-simulation feedback loop
// go!
#[derive(Subcommand)]
enum Commands {
#[command(flatten)]
Roblox(inference::Commands),
#[command(flatten)]
Source(training::Commands),
}
fn main() {
training();
let cli = Cli::parse();
match cli.command {
Commands::Roblox(commands) => commands.run(),
Commands::Source(commands) => commands.run(),
}
}

63
src/net.rs Normal file
View File

@@ -0,0 +1,63 @@
use burn::backend::Autodiff;
use burn::nn::{Dropout, DropoutConfig, Linear, LinearConfig, Relu};
use burn::prelude::*;
pub type InferenceBackend = burn::backend::Cuda<f32>;
pub type TrainingBackend = Autodiff<InferenceBackend>;
pub const SIZE: glam::UVec2 = glam::uvec2(64, 36);
pub const DEPTH_SIZE: usize = (SIZE.x * SIZE.y) as usize;
pub const POSITION_HISTORY: usize = 10;
pub const POSITION_HISTORY_SIZE: usize = POSITION_HISTORY * 5;
const INPUT: usize = DEPTH_SIZE + POSITION_HISTORY_SIZE;
pub const HIDDEN: [usize; 3] = [INPUT >> 3, INPUT >> 5, INPUT >> 7];
// MoveForward
// MoveLeft
// MoveBack
// MoveRight
// Jump
// mouse_dx
// mouse_dy
pub const OUTPUT: usize = 7;
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
input: Linear<B>,
dropout: Dropout,
hidden: [Linear<B>; HIDDEN.len() - 1],
output: Linear<B>,
activation: Relu,
}
impl<B: Backend> Net<B> {
pub fn init(device: &B::Device) -> Self {
let mut it = HIDDEN.into_iter();
let mut last_size = it.next().unwrap();
let input = LinearConfig::new(INPUT, last_size).init(device);
let hidden = core::array::from_fn(|_| {
let size = it.next().unwrap();
let layer = LinearConfig::new(last_size, size).init(device);
last_size = size;
layer
});
let output = LinearConfig::new(last_size, OUTPUT).init(device);
let dropout = DropoutConfig::new(0.1).init();
Self {
input,
dropout,
hidden,
output,
activation: Relu::new(),
}
}
pub fn forward(&self, input: Tensor<B, 2>, depth: Tensor<B, 2>) -> Tensor<B, 2> {
let x = self.dropout.forward(depth);
let x = Tensor::cat(vec![input, x], 1);
let x = self.input.forward(x);
let mut x = self.activation.forward(x);
for layer in &self.hidden {
x = layer.forward(x);
x = self.activation.forward(x);
}
self.output.forward(x)
}
}

230
src/training.rs Normal file
View File

@@ -0,0 +1,230 @@
#[derive(clap::Subcommand)]
pub enum Commands {
Train(TrainSubcommand),
}
impl Commands {
pub fn run(self) {
match self {
Commands::Train(subcommand) => subcommand.run(),
}
}
}
#[derive(clap::Args)]
pub struct TrainSubcommand {
#[arg(long)]
gpu_id: Option<usize>,
#[arg(long)]
epochs: Option<usize>,
#[arg(long)]
learning_rate: Option<f64>,
#[arg(long)]
map_file: std::path::PathBuf,
#[arg(long)]
bot_file: std::path::PathBuf,
}
impl TrainSubcommand {
fn run(self) {
training(
self.gpu_id.unwrap_or_default(),
self.epochs.unwrap_or(100_000),
self.learning_rate.unwrap_or(0.001),
self.map_file,
self.bot_file,
);
}
}
use burn::nn::loss::{MseLoss, Reduction};
use burn::optim::{AdamConfig, GradientsParams, Optimizer};
use burn::prelude::*;
use crate::inputs::InputGenerator;
use crate::net::{DEPTH_SIZE, Net, OUTPUT, POSITION_HISTORY_SIZE, TrainingBackend};
use strafesnet_roblox_bot_file::v0;
fn training(
gpu_id: usize,
epochs: usize,
learning_rate: f64,
map_file: std::path::PathBuf,
bot_file: std::path::PathBuf,
) {
// read files
let map_file = std::fs::read(map_file).unwrap();
let bot_file = std::fs::read(bot_file).unwrap();
// load map
let map = strafesnet_snf::read_map(std::io::Cursor::new(map_file))
.unwrap()
.into_complete_map()
.unwrap();
// load replay
let timelines =
strafesnet_roblox_bot_file::v0::read_all_to_block(std::io::Cursor::new(bot_file)).unwrap();
let bot = strafesnet_roblox_bot_player::bot::CompleteBot::new(timelines).unwrap();
let world_offset = bot.world_offset();
let timelines = bot.timelines();
// set up graphics
let mut g = InputGenerator::new(&map);
// training data
let training_samples = timelines.input_events.len() - 1;
let input_size = POSITION_HISTORY_SIZE * size_of::<f32>();
let depth_size = DEPTH_SIZE * size_of::<f32>();
let mut inputs = Vec::with_capacity(input_size * training_samples);
let mut depth = Vec::with_capacity(depth_size * training_samples);
let mut targets = Vec::with_capacity(OUTPUT * training_samples);
// generate all frames
println!("Generating {training_samples} frames of depth textures...");
let mut it = timelines.input_events.iter();
// grab mouse position from first frame, omitting one frame from the training data
let first = it.next().unwrap();
let mut last_mx = first.event.mouse_pos.x;
let mut last_my = first.event.mouse_pos.y;
for input_event in it {
let mouse_dx = input_event.event.mouse_pos.x - last_mx;
let mouse_dy = input_event.event.mouse_pos.y - last_my;
last_mx = input_event.event.mouse_pos.x;
last_my = input_event.event.mouse_pos.y;
// set targets
targets.extend([
// MoveForward
input_event
.event
.game_controls
.contains(v0::GameControls::MoveForward) as i32 as f32,
// MoveLeft
input_event
.event
.game_controls
.contains(v0::GameControls::MoveLeft) as i32 as f32,
// MoveBack
input_event
.event
.game_controls
.contains(v0::GameControls::MoveBack) as i32 as f32,
// MoveRight
input_event
.event
.game_controls
.contains(v0::GameControls::MoveRight) as i32 as f32,
// Jump
input_event
.event
.game_controls
.contains(v0::GameControls::Jump) as i32 as f32,
mouse_dx,
mouse_dy,
]);
// find the closest output event to the input event time
let output_event_index = timelines
.output_events
.binary_search_by(|event| event.time.partial_cmp(&input_event.time).unwrap());
let output_event = match output_event_index {
// found the exact same timestamp
Ok(output_event_index) => &timelines.output_events[output_event_index],
// found first index greater than the time.
// check this index and the one before and return the closest one
Err(insert_index) => timelines
.output_events
.get(insert_index)
.into_iter()
.chain(
insert_index
.checked_sub(1)
.and_then(|index| timelines.output_events.get(index)),
)
.min_by(|&e0, &e1| {
(e0.time - input_event.time)
.abs()
.partial_cmp(&(e1.time - input_event.time).abs())
.unwrap()
})
.unwrap(),
};
fn vec3(v: v0::Vector3) -> glam::Vec3 {
glam::vec3(v.x, v.y, v.z)
}
fn angles(a: v0::Vector3) -> glam::Vec2 {
glam::vec2(a.y, a.x)
}
let pos = vec3(output_event.event.position) - world_offset;
let angles = angles(output_event.event.angles);
g.generate_inputs(pos, angles, &mut inputs, &mut depth);
}
let device = burn::backend::cuda::CudaDevice::new(gpu_id);
let mut model: Net<TrainingBackend> = Net::init(&device);
let num_params = model.num_params();
println!("Training model ({} parameters)", num_params);
let mut optim = AdamConfig::new().init();
let inputs = Tensor::from_data(
TensorData::new(
inputs,
Shape::new([training_samples, POSITION_HISTORY_SIZE]),
),
&device,
);
let depth = Tensor::from_data(
TensorData::new(depth, Shape::new([training_samples, DEPTH_SIZE])),
&device,
);
let targets = Tensor::from_data(
TensorData::new(targets, Shape::new([training_samples, OUTPUT])),
&device,
);
let mut best_model = model.clone();
let mut best_loss = f32::INFINITY;
for epoch in 0..epochs {
let predictions = model.forward(inputs.clone(), depth.clone());
let loss = MseLoss::new().forward(predictions, targets.clone(), Reduction::Mean);
let loss_scalar = loss.clone().into_scalar();
if epoch == 0 {
// kinda a fake print, but that's what is happening after this point
println!("Compiling optimized GPU kernels...");
}
let grads = loss.backward();
let grads = GradientsParams::from_grads(grads, &model);
// get the best model
if loss_scalar < best_loss {
best_loss = loss_scalar;
best_model = model.clone();
}
model = optim.step(learning_rate, model, grads);
if epoch % (epochs >> 4) == 0 || epoch == epochs - 1 {
println!(" epoch {epoch:>5} | loss = {loss_scalar:.8} | best_loss = {best_loss:.8}");
}
}
let date_string = format!("{}_{}.model", num_params, best_loss);
best_model
.save_file(
date_string,
&burn::record::BinFileRecorder::<burn::record::FullPrecisionSettings>::new(),
)
.unwrap();
}