jvm-rs/crates/core/src/thread.rs
2025-12-25 10:46:03 +10:30

632 lines
19 KiB
Rust

use crate::class::{ClassRef, InitState, RuntimeClass};
use crate::class_file::{MethodData, MethodRef};
use crate::class_loader::{ClassLoader, LoaderRef};
use crate::error::VmError;
use crate::frame::Frame;
use crate::native::jni::create_jni_function_table;
use crate::objects::object::{ObjectReference, ReferenceKind};
use crate::objects::object_manager::ObjectManager;
use crate::value::{Primitive, Value};
use crate::vm::Vm;
use crate::{
BaseType, FieldType, MethodDescriptor, ThreadId, generate_jni_method_name, set_last_native,
};
use jni::sys::{JNIEnv, jboolean, jbyte, jchar, jdouble, jfloat, jint, jlong, jobject, jshort};
use libffi::middle::*;
use log::{LevelFilter, debug, trace, warn};
use parking_lot::{Condvar, Mutex, Once, RwLock};
use std::any::Any;
use std::cell::RefCell;
use std::collections::VecDeque;
use std::sync::atomic::Ordering;
use std::sync::{Arc, OnceLock};
use std::thread;
/// Guards the one-time `env_logger` set-up performed inside `VmThread::invoke`.
static INIT_LOGGER: Once = Once::new();
/// Outcome of invoking a method: `Ok(Some(value))` when a value is returned,
/// `Ok(None)` when there is none (e.g. a native method without a return type),
/// or the `VmError` that aborted the call.
type MethodCallResult = Result<Option<Value>, VmError>;
// Thread-local storage for current thread ID
// Starts out as `None`; `VmThread::set_current` fills it in and
// `VmThread::current_id` reads it back (panicking while it is still `None`).
// In single-threaded mode: stores the one thread ID
// In multi-threaded mode: each OS thread has its own thread ID
thread_local! {
static CURRENT_THREAD_ID: RefCell<Option<ThreadId>> = RefCell::new(None);
}
/// A thread of execution.
///
/// Instances are built by `VmThread::new`, which hands them out behind an `Arc`.
pub struct VmThread {
/// Identifier allocated from `vm.next_id` when the thread is created.
pub id: ThreadId,
/// The VM this thread belongs to; used for native symbol lookup and handed to
/// every `Frame` the thread executes.
pub vm: Arc<Vm>,
/// Class loader consulted by `get_class`, `invoke` and `invoke_virtual`;
/// `new` falls back to `vm.loader` when no loader is supplied.
pub loader: Arc<Mutex<ClassLoader>>,
/// Frames pushed by `execute_method`; a frame is popped again only when its
/// execution returns `Ok`.
pub frame_stack: Mutex<Vec<Frame>>,
/// Shared object manager (a clone of `vm.gc`) used for allocation and lookups.
pub gc: Arc<RwLock<ObjectManager>>,
/// JNI function table built by `create_jni_function_table` from a pointer to
/// this thread; its address is passed as the first argument of native calls.
pub jni_env: JNIEnv,
/// Object ID of the thread object created by `bootstrap_mirror`; unset until then.
pub mirror: OnceLock<u32>,
/// Condition variable `ensure_initialised` waits on while another thread is
/// initializing a class.
init_condvar: Condvar,
}
impl VmThread {
/// Creates a new VM thread and returns it behind an `Arc`.
///
/// The thread ID is taken from `vm.next_id` and the heap handle is a clone of
/// `vm.gc`. When `loader` is `None`, the VM's own loader (`vm.loader`) is used.
pub fn new(vm: Arc<Vm>, loader: Option<LoaderRef>) -> Arc<Self> {
    let id = ThreadId(vm.next_id.fetch_add(1, Ordering::SeqCst));
    // Clone the VM's loader only when the caller did not provide one.
    let loader = loader.unwrap_or_else(|| vm.loader.clone());
    let gc = vm.gc.clone();
    // `new_cyclic` exposes the thread's final address before the value exists;
    // that address is handed to `create_jni_function_table`.
    Arc::new_cyclic(|weak_self| {
        let jni_env = create_jni_function_table(weak_self.as_ptr() as *mut VmThread);
        Self {
            id,
            vm,
            loader,
            frame_stack: Default::default(),
            gc,
            jni_env,
            mirror: Default::default(),
            init_condvar: Default::default(),
        }
    })
}
/// Returns the thread ID registered for the calling OS thread.
///
/// # Panics
/// Panics with "No current thread set" when `set_current` has not been called
/// on this OS thread yet.
pub fn current_id() -> ThreadId {
    CURRENT_THREAD_ID.with(|slot| {
        let registered = *slot.borrow();
        registered.expect("No current thread set")
    })
}
/// Registers `id` as the current thread ID of the calling OS thread,
/// overwriting any ID stored earlier.
pub fn set_current(id: ThreadId) {
    CURRENT_THREAD_ID.with(|slot| {
        let _ = slot.replace(Some(id));
    });
}
/// Looks up the `VmThread` registered in `vm.threads` under the calling OS
/// thread's current ID.
///
/// # Panics
/// Panics when no current ID is set, or when `vm.threads` has no entry for it.
pub fn current(vm: &Arc<Vm>) -> Arc<VmThread> {
    let id = Self::current_id();
    let entry = vm.threads.get(&id).unwrap();
    entry.clone()
}
/*/// Get or resolve a class, ensuring it and its dependencies are initialised.
/// Follows JVM Spec 5.5 for recursive initialisation handling.
pub fn get_or_resolve_class(&self, what: &str) -> Result<Arc<RuntimeClass>, VmError> {
// Phase 1: Load the class (short lock)
let runtime_class = self.loader.lock().unwrap().get_or_load(what, None, true)?;
// Phase 2: Collect classes that need initialisation (short lock)
let classes_to_init = {
let mut loader = self.loader.lock().unwrap();
let classes = loader.needs_init.clone();
loader.needs_init.clear();
classes
};
// Phase 3: Initialise each class (NO lock held - allows recursion)
for class in classes_to_init {
self.init(class)?;
}
Ok(runtime_class)
}*/
/// Loads (or fetches) the class named `what` through this thread's loader and
/// makes sure a mirror object is being, or has been, created for it.
///
/// # Errors
/// Propagates any error from the loader or from building the mirror.
pub fn get_class(&self, what: &str) -> Result<Arc<RuntimeClass>, VmError> {
    // The loader guard is a temporary of this statement, so the lock is
    // released before the mirror is built (which may re-enter `get_class`).
    let class = self.loader.lock().get_or_load(what, None)?;
    self.create_mirror_class(&class).map(|()| class)
}
/// Initialize a class following JVM Spec 5.5.
///
/// Handles recursive initialization by tracking which OS thread is initializing:
/// a re-entrant request from the initializing thread returns `Ok(())` at once.
/// Otherwise the superclass is initialized first, then (for non-interface
/// classes) each entry of `class.interfaces` that has a default method, and
/// finally the class's own `<clinit>` if it declares one.
///
/// # Errors
/// - `VmError::LoaderError` when another thread is currently initializing the
///   class (this method does not wait) or when an earlier attempt failed.
/// - Any error raised while initializing a supertype or running `<clinit>`; the
///   class is then recorded as `InitState::Error`.
pub fn init(&self, class: Arc<RuntimeClass>) -> Result<(), VmError> {
    let current_thread = thread::current().id();
    // Check and update initialization state. The guard lives only inside this
    // block so the state lock is not held while initializer code runs.
    {
        let mut state = class.init_state.lock();
        match &*state {
            InitState::Initialized => {
                // Already initialized, nothing to do
                return Ok(());
            }
            InitState::Initializing(tid) if *tid == current_thread => {
                // JVM Spec 5.5: Recursive initialization by same thread is allowed
                warn!(
                    "Class {} already being initialized by this thread (recursive)",
                    class.this_class
                );
                return Ok(());
            }
            InitState::Initializing(_tid) => {
                // Different thread is initializing - in a real JVM we'd wait
                // For now, just return an error
                return Err(VmError::LoaderError(format!(
                    "Class {} is being initialized by another thread",
                    class.this_class
                )));
            }
            InitState::Error(msg) => {
                return Err(VmError::LoaderError(format!(
                    "Class {} initialization previously failed: {}",
                    class.this_class, msg
                )));
            }
            InitState::NotInitialized => {
                // Mark as being initialized by this thread
                *state = InitState::Initializing(current_thread);
            }
        }
    }
    // Perform actual initialisation
    trace!("Initializing class: {}", class.this_class);
    let result = (|| -> Result<(), VmError> {
        // Initialize superclass first (if any)
        if let Some(ref super_class) = class.super_class {
            self.init(super_class.clone())?;
        }
        if !class.access_flags.INTERFACE {
            for interface in class.interfaces.iter() {
                if interface.has_default_method() {
                    self.init(interface.clone())?;
                }
            }
        }
        // Run <clinit> if present (note: <clinit> is NOT inherited, only look in this class)
        if let Some(method) = class.methods.iter().find(|m| m.name == "<clinit>") {
            self.execute_method(&class, method, vec![])?;
        }
        Ok(())
    })();
    // Update state based on result
    {
        let mut state = class.init_state.lock();
        match &result {
            Ok(()) => {
                *state = InitState::Initialized;
                trace!("Class {} initialized successfully", class.this_class);
            }
            Err(e) => {
                *state = InitState::Error(format!("{:?}", e));
            }
        }
    }
    // The class has left `Initializing`: wake anything parked on this thread's
    // condvar (the one `ensure_initialised` waits on) so it re-checks the state.
    self.init_condvar.notify_all();
    result
}
/// Makes sure `class` is initialized, blocking while another OS thread is in
/// the middle of initializing it.
///
/// A re-entrant request from the initializing thread returns `Ok(())` at once.
/// When this call claims the class, it initializes the superclass, then (for
/// non-interface classes) each entry of `class.interfaces` that has a default
/// method, then runs the class's own `<clinit>` if present.
///
/// # Errors
/// - `VmError::LoaderError` when initialization failed earlier, including a
///   failure on the thread this call was waiting for.
/// - Any error raised while initializing a supertype or running `<clinit>`; the
///   class is then recorded as `InitState::Error`.
pub fn ensure_initialised(&self, class: &Arc<RuntimeClass>) -> Result<(), VmError> {
    // Upper bound on how long a waiter sleeps before re-reading the state.
    // `init_condvar` belongs to this `VmThread`, so a state change made through
    // a different `VmThread` cannot be relied on to signal it; the timed wait
    // guarantees progress either way.
    const INIT_POLL_INTERVAL: std::time::Duration = std::time::Duration::from_millis(10);

    let current_thread = thread::current().id();
    {
        let mut state = class.init_state.lock();
        // Re-dispatch on the state after every wake-up instead of assuming
        // which states are possible once the wait ends.
        loop {
            match &*state {
                InitState::Initialized => return Ok(()),
                InitState::Initializing(tid) if *tid == current_thread => return Ok(()),
                InitState::Initializing(_) => {
                    // Another thread is initializing: release the state lock,
                    // sleep until notified or until the interval elapses.
                    let _ = self.init_condvar.wait_for(&mut state, INIT_POLL_INTERVAL);
                }
                InitState::Error(msg) => {
                    return Err(VmError::LoaderError(format!(
                        "Class {} initialization previously failed: {}",
                        class.this_class, msg
                    )));
                }
                InitState::NotInitialized => {
                    // Claim the class for this thread and go initialize it.
                    *state = InitState::Initializing(current_thread);
                    break;
                }
            }
        }
    }
    // The state lock is released here; initializer code runs without it.
    let result = (|| -> Result<(), VmError> {
        if let Some(ref super_class) = class.super_class {
            self.ensure_initialised(super_class)?;
        }
        if !class.access_flags.INTERFACE {
            for interface in class.interfaces.iter() {
                if interface.has_default_method() {
                    self.ensure_initialised(interface)?;
                }
            }
        }
        // `<clinit>` is looked up in this class only.
        if let Some(method) = class.methods.iter().find(|m| m.name == "<clinit>") {
            if class.this_class.contains("Thread") {
                trace!("executing init for: {}", class.this_class);
            }
            self.execute_method(class, method, vec![])?;
        }
        Ok(())
    })();
    {
        let mut state = class.init_state.lock();
        match &result {
            Ok(()) => *state = InitState::Initialized,
            Err(e) => *state = InitState::Error(format!("{:?}", e)),
        }
    }
    // Wake waiters parked on this thread's condvar so they see the final state
    // immediately rather than after their poll interval.
    self.init_condvar.notify_all();
    result
}
/// creates a mirror java/lang/Class Object and binds it to this RuntimeClass
fn create_mirror_class(&self, class: &Arc<RuntimeClass>) -> Result<(), VmError> {
if class.mirror.get().is_some() || class.mirror_in_progress.swap(true, Ordering::SeqCst) {
return Ok(()); // already has a mirror
}
let class_class = if class.this_class == "java/lang/Class" {
Arc::clone(class)
} else {
self.get_class("java/lang/Class")?
};
let string = self.intern_string(&class.this_class);
let component_type = if (class.this_class.starts_with("[")) {
let component = self.get_class(class.this_class.strip_prefix("[").unwrap())?;
Some(self.gc.read().get(component.mirror()))
} else {
None
};
let class_obj = self.gc.write().new_class(
class_class,
Some(ReferenceKind::ObjectReference(string)),
None,
class.access_flags,
false,
component_type,
);
let id = class_obj.lock().id;
class.mirror.set(id).expect("woops, id already set");
Ok(())
}
/// Invokes the `main` method of class `what` with no arguments, using the
/// descriptor produced by `MethodDescriptor::psvm()`, and discards whatever the
/// call returns.
///
/// # Errors
/// Propagates any error from `invoke`.
pub fn invoke_main(&self, what: &str) -> Result<(), VmError> {
    let entry_point = MethodRef {
        class: what.to_owned(),
        name: String::from("main"),
        desc: MethodDescriptor::psvm(),
    };
    self.invoke(entry_point, vec![]).map(|_| ())
}
/// Resolves `method_reference` and executes the method with `args`.
///
/// The method is looked up starting from the referenced class; execution then
/// happens against the class named by the resolved method's own `class` field.
///
/// Side effect: once the heap holds more than 2350 objects, the first call to
/// get here installs a trace-level `env_logger` (guarded by `INIT_LOGGER`).
///
/// # Errors
/// Propagates class-loading, method-lookup and execution errors.
pub fn invoke(&self, method_reference: MethodRef, args: Vec<Value>) -> MethodCallResult {
    // Read the size in its own statement so the heap read lock is gone
    // before the logger is configured.
    let heap_size = self.gc.read().objects.len();
    if heap_size > 2350 {
        INIT_LOGGER.call_once(|| {
            let mut builder = env_logger::Builder::from_default_env();
            builder
                .filter_level(LevelFilter::Trace)
                .filter_module("deku", LevelFilter::Warn)
                .filter_module("roast_vm_core::class_file::class_file", LevelFilter::Info)
                .filter_module("roast_vm_core::attributes", LevelFilter::Info)
                .filter_module("roast_vm_core::instructions", LevelFilter::Info);
            builder.init();
        });
    }
    let referenced = self.get_class(&method_reference.class)?;
    let method = referenced.find_method(&method_reference.name, &method_reference.desc)?;
    let declaring = self.loader.lock().get_or_load(&*method.class, None)?;
    self.execute_method(&declaring, &method, args)
}
/// Resolves `method_reference` against the supplied `class` (rather than the
/// class named in the reference) and executes the method found with `args`.
///
/// Execution happens against the class named by the resolved method's own
/// `class` field, fetched through this thread's loader.
///
/// # Errors
/// Propagates method-lookup, class-loading and execution errors.
pub fn invoke_virtual(
    &self,
    method_reference: MethodRef,
    class: Arc<RuntimeClass>,
    args: Vec<Value>,
) -> MethodCallResult {
    let target = class.find_method(&method_reference.name, &method_reference.desc)?;
    let owner = self.loader.lock().get_or_load(&*target.class, None)?;
    self.execute_method(&owner, &target, args)
}
/// Calls the native implementation of `method` through libffi.
///
/// The symbol is looked up first under the name generated without parameter
/// types and, failing that, under the name generated with them. `args` must
/// start with the receiver (or class) reference; `build_args` prepends the
/// `JNIEnv` pointer.
///
/// Returns `Ok(None)` when the descriptor has no return type, otherwise the
/// native result converted to a `Value`. A returned object handle of `0` becomes
/// a null reference; any other handle is looked up in the object manager.
///
/// # Errors
/// - `VmError::Debug` for `Java_java_lang_reflect_Array_newArray`, which needs a
///   VM-specific implementation.
/// - `VmError::NativeError` when neither symbol name can be located.
pub fn invoke_native(&self, method: &MethodRef, args: Vec<Value>) -> MethodCallResult {
    let symbol_name = generate_jni_method_name(method, false);
    if symbol_name.contains("Java_java_lang_reflect_Array_newArray") {
        return Err(VmError::Debug(
            "RoastVM specific implementation required for Java_java_lang_reflect_Array_newArray",
        ));
    }
    // SAFETY: relies on the located symbol being a function whose C signature
    // matches the CIF derived from `method.desc`, and on `storage` (which owns
    // every argument value `built_args` points into) staying alive for the
    // duration of the call. NOTE(review): neither property is checked here.
    unsafe {
        let p = self
            .vm
            .find_native_method(&symbol_name)
            .or_else(|| {
                let name_with_params = generate_jni_method_name(method, true);
                self.vm.find_native_method(&name_with_params)
            })
            // Build the error message only when the lookup actually failed.
            .ok_or_else(|| {
                VmError::NativeError(format!(
                    "Link error: Unable to locate symbol {symbol_name}"
                ))
            })?;
        // build pointer to native fn
        let cp = CodePtr::from_ptr(p);
        // Logged instead of printed so the hosted program's stdout stays clean.
        trace!("invoke native fn: {}", symbol_name);
        let mut storage = Vec::new();
        trace!("passing {} to native fn", Value::format_vec(&args));
        let deq_args = VecDeque::from(args);
        let built_args = build_args(
            deq_args,
            &mut storage,
            &self.jni_env as *const _ as *mut JNIEnv,
        );
        let cif = method.build_cif();
        set_last_native(&symbol_name);
        match &method.desc.return_type {
            None => {
                cif.call::<()>(cp, built_args.as_ref());
                Ok(None)
            }
            Some(FieldType::Base(BaseType::Long)) => {
                let v = cif.call::<jlong>(cp, built_args.as_ref());
                Ok(Some(Value::Primitive(Primitive::Long(v))))
            }
            Some(FieldType::Base(BaseType::Int)) => {
                let v = cif.call::<jint>(cp, built_args.as_ref());
                Ok(Some(v.into()))
            }
            Some(FieldType::Base(BaseType::Float)) => {
                let v = cif.call::<jfloat>(cp, built_args.as_ref());
                Ok(Some(v.into()))
            }
            Some(FieldType::Base(BaseType::Double)) => {
                let v = cif.call::<jdouble>(cp, built_args.as_ref());
                Ok(Some(v.into()))
            }
            Some(FieldType::Base(BaseType::Boolean)) => {
                let v = cif.call::<jboolean>(cp, built_args.as_ref());
                Ok(Some(v.into()))
            }
            Some(FieldType::Base(BaseType::Byte)) => {
                let v = cif.call::<jbyte>(cp, built_args.as_ref());
                Ok(Some(v.into()))
            }
            Some(FieldType::Base(BaseType::Char)) => {
                let v = cif.call::<jchar>(cp, built_args.as_ref());
                Ok(Some(v.into()))
            }
            Some(FieldType::Base(BaseType::Short)) => {
                let v = cif.call::<jshort>(cp, built_args.as_ref());
                Ok(Some(v.into()))
            }
            Some(FieldType::ClassType(_)) | Some(FieldType::ArrayType(_)) => {
                let v = cif.call::<jobject>(cp, built_args.as_ref());
                // Convert jobject (u32 ID) to Reference
                let obj_id = v as u32;
                if obj_id == 0 {
                    // Null reference
                    Ok(Some(Value::Reference(None)))
                } else {
                    // Look up the object in the ObjectManager
                    let gc = self.gc.read();
                    let reference_kind = gc.get(obj_id);
                    Ok(Some(Value::Reference(Some(reference_kind))))
                }
            }
        }
    }
}
/// Executes `method` of `class` with `args` on this thread.
///
/// Native methods are forwarded to `invoke_native`; for static natives the
/// class mirror is prepended to the arguments (this waits until the class's
/// mirror has been set). Every other method gets a new `Frame`, which is pushed
/// onto `frame_stack`, executed, and popped again if execution succeeded.
///
/// # Panics
/// Panics when a non-native method carries no code (`method.code` is `None`).
fn execute_method(
    &self,
    class: &Arc<RuntimeClass>,
    method: &MethodData,
    args: Vec<Value>,
) -> MethodCallResult {
    debug!("Executing {}", method.name);
    let method_ref = MethodRef {
        class: class.this_class.clone(),
        name: method.name.clone(),
        desc: method.desc.clone(),
    };
    if method.flags.ACC_NATIVE {
        // One extra slot for the class mirror a static native receives.
        let mut native_args = Vec::with_capacity(args.len() + 1);
        if method.flags.ACC_STATIC {
            let jclass = self.vm.gc.read().get(*class.mirror.wait());
            native_args.push(Value::Reference(Some(jclass)));
        }
        native_args.extend(args);
        return self.invoke_native(&method_ref, native_args);
    }
    // `method_ref` is not needed after this point, so it is moved, not cloned.
    let mut frame = Frame::new(
        class.clone(),
        method_ref,
        method.code.clone().unwrap(),
        class.constant_pool.clone(),
        args,
        self.vm.clone(),
        method.line_number_table.clone(),
    );
    self.frame_stack.lock().push(frame.clone());
    let result = frame.execute();
    // The frame is only removed on success; after an error it stays on the
    // stack (presumably so it remains available for diagnostics — TODO confirm).
    if result.is_ok() {
        self.frame_stack.lock().pop();
    }
    result
}
// pub fn print_stack_trace(&mut self) {
// // Get a lock on the frame stack
// let guard = self.frame_stack.lock();
// // Reverse - most recent frame first (like Java does)
// for frame_arc in guard.iter().rev() {
// // Get a lock on the individual frame
// let frame = frame_arc.lock();
// let method = &frame.method_ref;
// // Internal format uses '/', Java stack traces use '.'
// let class_name = method.class.replace("/", ".");
//
// match (&frame.class.source_file, &frame.current_line_number()) {
// (Some(file), Some(line)) => {
// eprintln!("\tat {}.{}({}:{})", class_name, method.name, file, line)
// }
// (Some(file), None) => eprintln!("\tat {}.{}({})", class_name, method.name, file),
// _ => eprintln!("\tat {}.{}(Unknown Source)", class_name, method.name),
// }
// }
// }
/// Allocates an instance of `class` and runs the constructor whose descriptor
/// is `desc` on it, passing the new object followed by `args`.
///
/// The constructor is resolved before anything is allocated, so a missing
/// `<init>` does not leave an unconstructed object behind on the heap.
///
/// # Errors
/// Propagates the lookup error when no matching `<init>` exists, and any error
/// raised while the constructor runs.
pub fn new_object(
    &self,
    class: ClassRef,
    desc: MethodDescriptor,
    args: &[Value],
) -> Result<ObjectReference, VmError> {
    let constructor = class.find_method("<init>", &desc)?;
    let obj = self.gc.write().new_object(class.clone());
    // Receiver first, then the caller's arguments, built in one allocation.
    let mut ctor_args: Vec<Value> = Vec::with_capacity(args.len() + 1);
    ctor_args.push(obj.clone().into());
    ctor_args.extend_from_slice(args);
    self.invoke(constructor.into(), ctor_args)?;
    Ok(obj)
}
}
/// Lays out the argument list for a native call.
///
/// Every argument value is boxed into `storage` (which therefore has to outlive
/// the returned list) in this order: the `JNIEnv` pointer, the receiver/class
/// handle taken from the front of `params`, then the remaining parameters.
/// References are passed as their object ID, with `0` standing for null.
///
/// # Panics
/// Panics when `params` does not start with a reference, or when a
/// `Value::Padding` is encountered among the remaining parameters.
fn build_args<'a>(
    mut params: VecDeque<Value>,
    storage: &'a mut Vec<Box<dyn Any>>,
    jnienv: *mut JNIEnv,
) -> Vec<Arg<'a>> {
    // Slot 0: JNIEnv
    storage.push(Box::new(jnienv));

    // Slot 1: `this` for instance methods, the class for static ones.
    let receiver_id = match params.pop_front() {
        Some(Value::Reference(Some(ref_kind))) => ref_kind.id(),
        Some(Value::Reference(None)) => 0, // null
        _ => panic!("first arg must be reference"),
    };
    storage.push(Box::new(receiver_id as jobject));

    // Remaining slots: one box per parameter, in call order.
    storage.extend(params.into_iter().map(|value| -> Box<dyn Any> {
        match value {
            Value::Primitive(Primitive::Int(x)) => Box::new(x),
            Value::Primitive(Primitive::Boolean(x)) => Box::new(x),
            Value::Primitive(Primitive::Char(x)) => Box::new(x),
            Value::Primitive(Primitive::Float(x)) => Box::new(x),
            Value::Primitive(Primitive::Double(x)) => Box::new(x),
            Value::Primitive(Primitive::Byte(x)) => Box::new(x),
            Value::Primitive(Primitive::Short(x)) => Box::new(x),
            Value::Primitive(Primitive::Long(x)) => Box::new(x),
            Value::Reference(x) => {
                let id = x.map(|r| r.id()).unwrap_or(0) as jobject;
                Box::new(id)
            }
            Value::Padding => panic!("Uhh not possible chief"),
        }
    }));

    // Hand libffi a pointer to each boxed value.
    storage.iter().map(|boxed| arg(&**boxed)).collect()
}
/// Maps a JVM field type onto the libffi type used to pass it to native code.
///
/// Primitive types map to the fixed-width type of the corresponding JNI typedef;
/// class and array types are passed as pointers.
impl From<FieldType> for Type {
    fn from(value: FieldType) -> Self {
        match value {
            FieldType::Base(BaseType::Byte) => Type::i8(),
            FieldType::Base(BaseType::Char) => Type::u16(),
            FieldType::Base(BaseType::Double) => Type::f64(),
            FieldType::Base(BaseType::Float) => Type::f32(),
            FieldType::Base(BaseType::Int) => Type::i32(),
            FieldType::Base(BaseType::Long) => Type::i64(),
            FieldType::Base(BaseType::Short) => Type::i16(),
            // JNI defines `jboolean` as an unsigned 8-bit value.
            FieldType::Base(BaseType::Boolean) => Type::u8(),
            FieldType::ClassType(_) | FieldType::ArrayType(_) => Type::pointer(),
        }
    }
}
impl MethodRef {
    /// Builds the libffi call interface for this method's native entry point.
    ///
    /// The argument list is two pointers (the `JNIEnv*` and the receiver/class
    /// handle) followed by one entry per descriptor parameter; the result type
    /// follows the descriptor's return type, or `void` when there is none.
    fn build_cif(&self) -> Cif {
        let mut arg_types = vec![
            Type::pointer(), // JNIEnv*
            Type::pointer(), // jclass
        ];
        arg_types.extend(self.desc.parameters.iter().cloned().map(Type::from));

        let result_type = match &self.desc.return_type {
            Some(declared) => Type::from(declared.clone()),
            None => Type::void(),
        };

        Builder::new().args(arg_types).res(result_type).into_cif()
    }
}
impl VmThread {
/// Returns the heap string object for `utf`, creating it on first use.
///
/// Perhaps misleadingly named (it was once called `get_or_make_string`): it
/// first looks for an already interned string and returns that if it exists.
/// Otherwise a new string is created, which — per the original author's note —
/// currently always interns it.
///
/// # Panics
/// Panics when `java/lang/String` or `[B` cannot be loaded.
pub fn intern_string(&self, utf: &str) -> ObjectReference {
    // Fast path under the read lock only.
    let cached = self.gc.read().get_interned_string(utf);
    if let Some(existing) = cached {
        return existing;
    }

    let string_class = self.get_class("java/lang/String").unwrap();
    let byte_array_class = self.get_class("[B").unwrap();

    // Check again under the write lock: the string may have been interned
    // between dropping the read lock and getting here.
    let mut heap = self.gc.write();
    let already_interned = heap.get_interned_string(utf);
    match already_interned {
        Some(existing) => existing,
        None => heap.new_string(byte_array_class, string_class, utf),
    }
}
/// Creates the heap objects that represent this thread and records the thread
/// object's ID in `self.mirror`.
///
/// A thread group named "system", a `Thread$FieldHolder` and a thread named
/// "main" are allocated in that order; this thread's numeric ID is passed to
/// `new_thread` twice.
///
/// # Panics
/// Panics when one of the required classes cannot be loaded, or when
/// `self.mirror` has already been set.
pub fn bootstrap_mirror(&self) {
    let numeric_id = self.id.0 as jlong;

    let main_name = self.intern_string("main");
    let system_name = self.intern_string("system");

    let thread_class = self.get_class("java/lang/Thread").unwrap();
    let group_class = self.get_class("java/lang/ThreadGroup").unwrap();
    let holder_class = self.get_class("java/lang/Thread$FieldHolder").unwrap();

    // Each allocation takes and releases the heap write lock on its own.
    let heap = &self.gc;
    let group = heap
        .write()
        .new_thread_group(group_class, system_name, 10 as jint);
    let holder = heap
        .write()
        .new_holder(holder_class, group, 5 as jint, 5 as jint);
    let thread_obj = heap
        .write()
        .new_thread(thread_class, numeric_id, numeric_id, main_name, holder);

    let mirror_id = thread_obj.lock().id;
    self.mirror.set(mirror_id).unwrap();
}
}
// SAFETY:
// `VmThread` needs these manual impls because `jni_env` makes it neither `Send`
// nor `Sync` automatically. The author's stated invariants are:
// - jni_env pointer is owned by this VmThread and only used by its OS thread
// - Frame stack is behind Mutex
// NOTE(review): the first invariant is a usage convention; nothing visible in
// this file enforces that a `VmThread` is only driven from one OS thread.
unsafe impl Send for VmThread {}
unsafe impl Sync for VmThread {}