Browse Source

added code to initialize GPU context

main
Thomas Johnson 9 months ago
parent
commit
409360d632
  1. 3
      Cargo.toml
  2. 5
      src/main.rs
  3. 18
      src/map.rs
  4. 91
      src/map/gpu.rs
  5. 4
      src/map/shader.glsl

3
Cargo.toml

@ -10,4 +10,5 @@ rand = "0.8"
rand_distr = "0.4"
rayon = "1.5"
bytemuck = "1.7.3"
wgpu = "0.12"
wgpu = { version = "0.12", features = [ "spirv" ] }
pollster = "0.2.4"

5
src/main.rs

@ -19,6 +19,11 @@ fn main() {
Rgb([pix[0] as f32 / 65535.0, pix[1] as f32 / 65535.0, pix[2] as f32 / 65535.0])
});
let gpu_ctx = map::gpu::init(&image_in);
if let Some(_) = gpu_ctx {
println!("successfully created a GPU context");
}
let mut rng = rand::thread_rng();
let cfg = MapGeneticConfig::new(
image_in.clone(),

18
src/map.rs

@ -5,7 +5,7 @@ use rand::{Rng, distributions::{Uniform, WeightedIndex}};
use rand_distr::{Distribution, Bernoulli};
use image::{Rgb, GenericImage, ImageBuffer};
mod gpu;
pub mod gpu;
#[derive(Debug)]
pub struct Map
@ -104,14 +104,14 @@ impl Genetic for Map
}
fn mutate<R: Rng>(&self, mut rng: &mut R, cfg: &Self::Configuration) -> Self {
// first, remove elements
let mut new = self.clone();
// first, remove elements
let mut accum = 1.0;
let threshold = (cfg.mean_add as f32).exp().recip();
accum *= rng.gen::<f32>();
while accum > threshold {
let transform = cfg.random_transform(rng);
new.transforms.push(transform);
while new.transforms.len() != 0 && accum > threshold {
let which = rng.gen_range(0..new.transforms.len());
new.transforms.swap_remove(which);
accum *= rng.gen::<f32>();
}
@ -120,7 +120,7 @@ impl Genetic for Map
let threshold = (cfg.mean_add as f32).exp().recip();
accum *= rng.gen::<f32>();
while accum > threshold {
let which = rng.gen_range(0..self.transforms.len());
let which = rng.gen_range(0..new.transforms.len());
let mut matrix_coeff_dist = Uniform::new(-1.0, 1.0).sample_iter(&mut rng);
for i in 0..NCOLORS * NCOLORS {
new.transforms[which].color_matrix[i] = matrix_coeff_dist.next().unwrap();
@ -132,9 +132,9 @@ impl Genetic for Map
let mut accum = 1.0;
let threshold = (cfg.mean_add as f32).exp().recip();
accum *= rng.gen::<f32>();
while new.transforms.len() != 0 && accum > threshold {
let which = rng.gen_range(0..new.transforms.len());
new.transforms.swap_remove(which);
while accum > threshold {
let transform = cfg.random_transform(rng);
new.transforms.push(transform);
accum *= rng.gen::<f32>();
}
new

91
src/map/gpu.rs

@ -1,10 +1,97 @@
use wgpu::{Instance, Backends, Adapter, Device, Queue, ShaderModule, Buffer};
use wgpu::{Instance, Backends, Adapter, DeviceType, Features, Limits, DeviceDescriptor, Device, Queue, ShaderModuleDescriptor, ShaderSource, ShaderModule, BufferDescriptor, util::{BufferInitDescriptor, DeviceExt}, BufferUsages, Buffer};
use image::{ImageBuffer, Rgb};
use std::task::Poll;
use std::future::Future;
struct GpuContext {
#[repr(C)]
struct ImageFormat {
channel_stride: u32,
width: u32,
width_stride: u32,
height: u32,
height_stride: u32,
}
pub struct GpuContext {
device: Device,
queue: Queue,
shader_module: ShaderModule,
goal_image_buffer: Buffer,
image_format_buffer: Buffer,
scratch_image_buffers: (Buffer, Buffer),
}
pub fn init(goal_image: &ImageBuffer<Rgb<f32>, Vec<f32>>) -> Option<GpuContext> {
let inst = Instance::new(Backends::all());
// Look for devices with the features we want
let required_features = Features::SHADER_FLOAT64 | Features::CLEAR_COMMANDS;
let adapters: Vec<_> = inst.enumerate_adapters(Backends::all()).filter(|a| a.features().contains(required_features)).collect();
if adapters.is_empty() {
return None;
}
// Pick the best type of device available
let mut target_device_type = DeviceType::Other;
if adapters.iter().any(|a| a.get_info().device_type == DeviceType::Cpu) {
target_device_type = DeviceType::Cpu;
}
if adapters.iter().any(|a| a.get_info().device_type == DeviceType::VirtualGpu) {
target_device_type = DeviceType::VirtualGpu;
}
if adapters.iter().any(|a| a.get_info().device_type == DeviceType::IntegratedGpu) {
target_device_type = DeviceType::IntegratedGpu;
}
if adapters.iter().any(|a| a.get_info().device_type == DeviceType::DiscreteGpu) {
target_device_type = DeviceType::DiscreteGpu;
}
println!("available adapters:");
for (i, adapter) in adapters.iter().enumerate() {
let info = adapter.get_info();
println!("{}: {}", i, info.name);
}
let (i, adapter) = adapters.into_iter().enumerate().filter(|(_, a)| a.get_info().device_type == target_device_type).next().unwrap();
println!("picking {}: {}", i, adapter.get_info().name);
let limits = adapter.limits();
let device_descriptor = DeviceDescriptor {
label: None,
features: required_features,
limits: limits.clone(),
};
let (device, queue) = pollster::block_on(adapter.request_device(&device_descriptor, None)).ok()?;
let sample_layout = goal_image.sample_layout();
let image_format = ImageFormat {
channel_stride: sample_layout.channel_stride as u32,
width: sample_layout.width,
width_stride: sample_layout.width_stride as u32,
height: sample_layout.height,
height_stride: sample_layout.height_stride as u32,
};
let image_format_buffer = device.create_buffer_init(&BufferInitDescriptor { label: None, contents: unsafe { std::slice::from_raw_parts(&image_format as *const ImageFormat as *const u8, std::mem::size_of::<ImageFormat>()) }, usage: BufferUsages::UNIFORM });
let goal_image_storage = &*goal_image.as_raw();
let goal_image_buffer = device.create_buffer_init(&BufferInitDescriptor { label: None, contents: bytemuck::cast_slice(goal_image_storage), usage: BufferUsages::UNIFORM | BufferUsages::STORAGE });
let scratch_image_buffer_gen = || device.create_buffer(&BufferDescriptor { label: None, size: (std::mem::size_of::<f32>() as u64 * image_format.width as u64 * image_format.height as u64), usage: BufferUsages::COPY_DST | BufferUsages::STORAGE | BufferUsages::MAP_READ, mapped_at_creation: false });
let scratch_image_buffers = (scratch_image_buffer_gen(), scratch_image_buffer_gen());
let shader_module = device.create_shader_module(&wgpu::include_spirv!("shader.spv"));
let context = GpuContext {
device,
queue,
shader_module,
goal_image_buffer,
image_format_buffer,
scratch_image_buffers,
};
Some(context)
}

4
src/map/shader.glsl

@ -0,0 +1,4 @@
#version 450
void main() {
}
Loading…
Cancel
Save