Hands-on Examples
From basic data loading to training ISAC baselines.
Usage & Examples
LAMBDA follows Sionna's Right-Handed Coordinate System.
- Basic: Load CSI
- Task 1: Localization
- Task 2: Beam Prediction
This example demonstrates how to load and inspect the compressed .npz multipath files.
import numpy as np

# Load the compressed CSI file
# (each csi_xxxxxx.npz stores per-path multipath components as 1D arrays).
data = np.load("path/to/csi_xxxxxx.npz")
# Access Multipath Components — one entry per detected propagation path.
a_real = data['a_real']      # real part of the complex path gain
a_imag = data['a_imag']      # imaginary part of the complex path gain
delays = data['tau']         # path delays (presumably seconds — confirm against generator)
doppler = data['doppler']    # per-path Doppler shifts, reported in Hz below
# Reconstruct Complex Gain
complex_gain = a_real + 1j * a_imag
# Access Angles (AoD / AoA)
# theta_* = elevation, phi_* = azimuth; _t = transmitter side, _r = receiver side.
theta_t, phi_t = data['theta_t'], data['phi_t']
theta_r, phi_r = data['theta_r'], data['phi_r']
print(f"Detected {len(delays)} paths.")
print(f"Max Doppler Shift: {np.max(np.abs(doppler)):.2f} Hz")
This example demonstrates an end-to-end pipeline for UAV Localization. It defines a custom PyTorch Dataset that loads RGB images, Depth maps, and multipath-based labels, and trains a regression network (modified MobileNetV2) to predict 3D coordinates.
Step 1: Dataset Preparation
Define a Dataset class to handle multi-modal inputs.
Show Code
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd
import os
import numpy as np
from tqdm import tqdm
from sklearn.model_selection import train_test_split
# Camera Position: X, Y, Z, Pitch, Yaw, Roll
# NOTE(review): units presumably meters (position) and degrees (rotation)
# in the simulator frame — confirm against the scene configuration.
CAMERA_POSE = np.array([-8.10, -157, -35.7, 40, -180, 0], dtype=np.float32)
# Paths
CSV_PATH = r"d:\exp\beam_labels.csv"  # label CSV produced by the Task-2 label-generation script
RGB_DIR = r"E:\Datasets1\San Francisco\Scene 1\roof_bs_01\cam"
DEPTH_DIR = r"E:\Datasets1\San Francisco\Scene 1\roof_bs_01\depth"
CSI_DIR = r"D:\multi_path_npz" # Path to csi_xxxx.npz files containing uav_pos
# Training hyperparameters
BATCH_SIZE = 16
EPOCHS = 20
LEARNING_RATE = 0.001
# --- 1. Dataset ---
class UAVLocDataset(Dataset):
    """Multi-modal UAV localization dataset.

    Each sample pairs an early-fused RGB-D image tensor and the static
    camera pose with a 3D UAV position label read from the matching CSI
    .npz file. The CSV provides the CSI filenames; RGB/depth/CSI files
    are resolved from a shared numeric id embedded in those names.
    """

    def __init__(self, csv_file, rgb_dir, depth_dir, csi_dir, transform=None):
        """
        Args:
            csv_file: CSV with a 'filename' column of csi_xxxxxx.npz names.
            rgb_dir: Directory containing img_<id>.png RGB frames.
            depth_dir: Directory containing depth_<id>.npz depth maps.
            csi_dir: Directory containing the CSI .npz label files.
            transform: Optional RGB transform; defaults to 512x512 resize
                plus ImageNet normalization.
        """
        self.df = pd.read_csv(csv_file)
        self.rgb_dir = rgb_dir
        self.depth_dir = depth_dir
        self.csi_dir = csi_dir
        # Static camera pose, identical for every sample.
        self.camera_pose = torch.tensor(CAMERA_POSE, dtype=torch.float32)
        if transform is None:
            self.transform = transforms.Compose([
                transforms.Resize((512, 512)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
        else:
            self.transform = transform
        # Depth maps are only resized and converted to a 1-channel tensor
        # (no normalization — raw depth values are kept).
        self.depth_transform = transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # 1. Derive per-sample file names from the CSI filename.
        row = self.df.iloc[idx]
        npz_filename = row['filename']  # csi_xxxxxx.npz
        file_id = npz_filename.replace('csi_', '').replace('.npz', '')
        rgb_name = f"img_{file_id}.png"
        depth_name = f"depth_{file_id}.npz"
        rgb_path = os.path.join(self.rgb_dir, rgb_name)
        depth_path = os.path.join(self.depth_dir, depth_name)
        csi_path = os.path.join(self.csi_dir, npz_filename)
        # 2. Load inputs (RGB + Depth). A corrupt/missing sample falls back
        # to zeros so one bad file does not abort the whole epoch.
        try:
            # RGB
            rgb_img = Image.open(rgb_path).convert('RGB')
            rgb_tensor = self.transform(rgb_img)
            # Depth (stored under an arbitrary key; take the first array)
            with np.load(depth_path) as data:
                key = data.files[0]
                depth_arr = data[key]
            depth_img = Image.fromarray(depth_arr, mode='F')
            depth_tensor = self.depth_transform(depth_img)  # (1, 512, 512)
            # Early Fusion: Concatenate RGB and Depth -> (4, 512, 512)
            input_img = torch.cat([rgb_tensor, depth_tensor], dim=0)
        except Exception as e:
            # Fixed: removed stray trailing '5' that made this a syntax error.
            print(f"Error loading images for {file_id}: {e}")
            input_img = torch.zeros(4, 512, 512)
        # 3. Load label (UAV position), mapping the stored [x, z, y] axis
        # convention (y negated) to the (x, y, z) target order.
        try:
            with np.load(csi_path) as data:
                raw_pos = data['uav_pos']  # shape expected (3,)
            # Flatten guards against (3, 1)-shaped arrays.
            raw_pos = raw_pos.flatten()
            t_x = float(raw_pos[0])
            t_z = float(raw_pos[1])   # z stored at index 1
            t_y = -float(raw_pos[2])  # y stored at index 2, sign-flipped
            # Final order: x, y, z
            target_pos = torch.tensor([t_x, t_y, t_z], dtype=torch.float32)
        except Exception as e:
            print(f"Error loading label for {file_id}: {e}")
            target_pos = torch.zeros(3, dtype=torch.float32)
        return input_img, self.camera_pose, target_pos
Step 2: Model Architecture
Modify MobileNetV2 to accept 4-channel inputs (RGB+D). The classifier is replaced with a regression head that fuses image features with camera pose to output the (x, y, z) coordinates.
Show Code
# --- 2. Network Model ---
class LocRegressionNet(nn.Module):
    """RGB-D + camera-pose regression network predicting (x, y, z).

    A randomly initialized MobileNetV2 backbone (stem widened to 4 input
    channels for RGB-D) extracts image features, which are globally
    average-pooled, concatenated with the 6-D camera pose, and mapped to
    3D coordinates by an MLP head.
    """

    def __init__(self, pose_dim=6, output_dim=3):
        super(LocRegressionNet, self).__init__()
        # MobileNetV2 backbone without pretrained weights.
        self.backbone = models.mobilenet_v2(weights=None)
        # Swap the stem conv for a 4-channel (RGB-D) version that keeps
        # the original kernel/stride/padding configuration.
        stem = self.backbone.features[0][0]
        self.backbone.features[0][0] = nn.Conv2d(
            in_channels=4,
            out_channels=stem.out_channels,
            kernel_size=stem.kernel_size,
            stride=stem.stride,
            padding=stem.padding,
            bias=False
        )
        # Only the convolutional feature stack is used (classifier dropped);
        # global average pooling collapses the spatial grid (16x16 for a
        # 512x512 input) down to a 1280-dim vector.
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.feature_extractor = self.backbone.features
        feature_dim = 1280  # MobileNetV2 final channel width
        # Regression head over fused image features + camera pose
        # (1280 + 6 = 1286 -> 3 coordinates).
        self.regressor = nn.Sequential(
            nn.Linear(feature_dim + pose_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim)  # Output: x, y, z
        )

    def forward(self, x_img, x_pose):
        """Predict UAV coordinates from an RGB-D image and the camera pose."""
        feats = self.feature_extractor(x_img)        # (B, 1280, H, W)
        pooled = torch.flatten(self.gap(feats), 1)   # (B, 1280)
        fused = torch.cat((pooled, x_pose), dim=1)   # (B, 1280 + pose_dim)
        return self.regressor(fused)
Step 3: Training & Evaluation
Use MSE Loss for regression.
Show Code
def main():
    """Train and evaluate the UAV localization regressor end-to-end.

    Splits the label CSV 80/20, trains LocRegressionNet with MSE loss,
    evaluates after each epoch, and checkpoints the best model by test MSE.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    # Data
    print("Preparing Data...")
    if not os.path.exists(CSV_PATH):
        print(f"Error: CSV file not found at {CSV_PATH}")
        return
    full_df = pd.read_csv(CSV_PATH)
    # Fixed random_state keeps the split reproducible across runs.
    train_df, test_df = train_test_split(full_df, test_size=0.2, random_state=42)
    # Save temp CSVs (UAVLocDataset reads its samples from a CSV path).
    train_df.to_csv('temp_loc_train.csv', index=False)
    test_df.to_csv('temp_loc_test.csv', index=False)
    train_dataset = UAVLocDataset('temp_loc_train.csv', RGB_DIR, DEPTH_DIR, CSI_DIR)
    test_dataset = UAVLocDataset('temp_loc_test.csv', RGB_DIR, DEPTH_DIR, CSI_DIR)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    # Model
    model = LocRegressionNet().to(device)
    # Loss: MSE for regression
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    print("Starting Training (Regression Task)...")
    best_loss = float('inf')
    for epoch in range(EPOCHS):
        model.train()
        running_loss = 0.0
        loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
        for imgs, poses, targets in loop:
            imgs = imgs.to(device)
            poses = poses.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            outputs = model(imgs, poses)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            loop.set_postfix(mse_loss=loss.item())
        avg_train_loss = running_loss / len(train_loader)
        # Evaluation
        model.eval()
        test_loss = 0.0
        with torch.no_grad():
            for imgs, poses, targets in test_loader:
                imgs = imgs.to(device)
                poses = poses.to(device)
                targets = targets.to(device)
                outputs = model(imgs, poses)
                loss = criterion(outputs, targets)
                test_loss += loss.item()
        avg_test_loss = test_loss / len(test_loader)
        # RMSE is sqrt of the mean per-coordinate squared error.
        print(f"Epoch {epoch+1}: Train MSE={avg_train_loss:.4f}, Test MSE={avg_test_loss:.4f}, Test RMSE={np.sqrt(avg_test_loss):.4f}")
        # Checkpoint whenever the test loss improves.
        if avg_test_loss < best_loss:
            best_loss = avg_test_loss
            torch.save(model.state_dict(), "best_uav_loc_model.pth")
            print(" [Saved Best Model]")
    print(f"Training Finished! Best Test MSE: {best_loss:.4f}")
    # Clean up the temporary split CSVs (only reached on normal completion).
    if os.path.exists('temp_loc_train.csv'): os.remove('temp_loc_train.csv')
    if os.path.exists('temp_loc_test.csv'): os.remove('temp_loc_test.csv')

if __name__ == "__main__":
    main()
Step 4: Experimental Results
After a parallel training ablation study over 100 epochs using the San Francisco Block 1 dataset, RGB+D achieves the best localization performance with a final RMSE of 2.58 m (MSE: 6.6631). This is followed by RGB only (RMSE: 2.93 m), while Depth alone reaches an RMSE of 3.94 m. Notably, Depth converges significantly faster in the earliest epochs, dropping from an initial MSE of 889.50 to 43.41 by the second epoch, which indicates that geometric structure provides a strong inductive bias for spatial regression in urban environments. With sufficient training, RGB catches up by utilizing semantic and texture cues, and fusion ultimately benefits from complementary geometry and appearance, yielding the highest final precision.
Show Full Training Logs
Ep 1/100: 100%|██████████████████████████████| 2003/2003 [08:28<00:00, 3.94it/s, rgb=639.06, depth=218.95, rgbd=329.28]
Stats Ep 1: rgb: 310.8004 | rgbd: 3077.3468 | depth: 889.5046
Ep 2/100: 100%|██████████████████████████████| 2003/2003 [06:14<00:00, 5.35it/s, rgb=626.10, depth=121.41, rgbd=142.21]
Stats Ep 2: rgb: 314.0181 | rgbd: 2039.5476 | depth: 43.4116
Ep 3/100: 100%|██████████████████████████████| 2003/2003 [03:34<00:00, 9.32it/s, rgb=402.87, depth=358.99, rgbd=611.07]
Stats Ep 3: rgb: 307.2469 | rgbd: 943.4265 | depth: 884.8918
Ep 4/100: 100%|██████████████████████████████| 2003/2003 [03:22<00:00, 9.91it/s, rgb=417.30, depth=350.50, rgbd=476.19]
Stats Ep 4: rgb: 305.9102 | rgbd: 689.0194 | depth: 102.0278
Ep 5/100: 100%|██████████████████████████████| 2003/2003 [04:01<00:00, 8.30it/s, rgb=531.35, depth=130.76, rgbd=205.58]
Stats Ep 5: rgb: 306.4170 | rgbd: 882.1281 | depth: 26.6658
Ep 6/100: 100%|██████████████████████████████| 2003/2003 [03:30<00:00, 9.53it/s, rgb=622.02, depth=138.42, rgbd=365.98]
Stats Ep 6: rgb: 306.3479 | rgbd: 709.2100 | depth: 18.4566
Ep 7/100: 100%|██████████████████████████████| 2003/2003 [03:31<00:00, 9.47it/s, rgb=372.38, depth=150.94, rgbd=100.19]
Stats Ep 7: rgb: 303.1200 | rgbd: 467.0916 | depth: 15.4920
Ep 8/100: 100%|██████████████████████████████| 2003/2003 [03:41<00:00, 9.05it/s, rgb=353.09, depth=299.04, rgbd=185.78]
Stats Ep 8: rgb: 302.0047 | rgbd: 543.7611 | depth: 72.3152
Ep 9/100: 100%|███████████████████████████████| 2003/2003 [03:15<00:00, 10.25it/s, rgb=455.31, depth=102.35, rgbd=88.01]
Stats Ep 9: rgb: 304.3221 | rgbd: 196.2449 | depth: 67.1170
Ep 10/100: 100%|███████████████████████████████| 2003/2003 [03:41<00:00, 9.03it/s, rgb=388.10, depth=54.54, rgbd=61.70]
Stats Ep 10: rgb: 307.8559 | rgbd: 642.5213 | depth: 667.3951
Ep 11/100: 100%|███████████████████████████████| 2003/2003 [03:13<00:00, 10.33it/s, rgb=707.45, depth=75.29, rgbd=81.29]
Stats Ep 11: rgb: 306.6384 | rgbd: 462.5986 | depth: 551.4559
Ep 12/100: 100%|███████████████████████████████| 2003/2003 [03:43<00:00, 8.96it/s, rgb=350.71, depth=30.32, rgbd=21.00]
Stats Ep 12: rgb: 293.6508 | rgbd: 475.4710 | depth: 58.7112
Ep 13/100: 100%|███████████████████████████████| 2003/2003 [03:19<00:00, 10.02it/s, rgb=202.38, depth=35.08, rgbd=25.44]
Stats Ep 13: rgb: 302.0266 | rgbd: 537.9584 | depth: 43.1008
Ep 14/100: 100%|███████████████████████████████| 2003/2003 [03:56<00:00, 8.45it/s, rgb=227.19, depth=32.19, rgbd=13.25]
Stats Ep 14: rgb: 304.0173 | rgbd: 282.8194 | depth: 41.2770
Ep 15/100: 100%|███████████████████████████████| 2003/2003 [03:25<00:00, 9.75it/s, rgb=251.82, depth=85.67, rgbd=83.02]
Stats Ep 15: rgb: 215.2936 | rgbd: 442.7491 | depth: 76.1542
Ep 16/100: 100%|███████████████████████████████| 2003/2003 [03:42<00:00, 9.00it/s, rgb=100.98, depth=27.60, rgbd=50.09]
Stats Ep 16: rgb: 95.2590 | rgbd: 158.3139 | depth: 31.9029
Ep 17/100: 100%|████████████████████████████████| 2003/2003 [03:37<00:00, 9.19it/s, rgb=65.15, depth=21.09, rgbd=15.65]
Stats Ep 17: rgb: 82.7200 | rgbd: 228.9488 | depth: 30.9142
Ep 18/100: 100%|████████████████████████████████| 2003/2003 [03:20<00:00, 9.99it/s, rgb=35.04, depth=19.12, rgbd=19.44]
Stats Ep 18: rgb: 65.0421 | rgbd: 515.7509 | depth: 33.0823
Ep 19/100: 100%|████████████████████████████████| 2003/2003 [03:25<00:00, 9.76it/s, rgb=79.00, depth=34.07, rgbd=72.35]
Stats Ep 19: rgb: 68.0477 | rgbd: 422.5246 | depth: 26.3802
Ep 20/100: 100%|████████████████████████████████| 2003/2003 [03:43<00:00, 8.98it/s, rgb=84.99, depth=30.89, rgbd=38.08]
Stats Ep 20: rgb: 74.8206 | rgbd: 402.6806 | depth: 22.5832
Ep 21/100: 100%|█████████████████████████████████| 2003/2003 [03:51<00:00, 8.67it/s, rgb=22.25, depth=15.03, rgbd=7.81]
Stats Ep 21: rgb: 71.1198 | rgbd: 354.0043 | depth: 31.6384
Ep 22/100: 100%|████████████████████████████████| 2003/2003 [03:14<00:00, 10.31it/s, rgb=47.78, depth=14.22, rgbd=22.33]
Stats Ep 22: rgb: 64.2219 | rgbd: 19.6502 | depth: 30.0421
Ep 23/100: 100%|████████████████████████████████| 2003/2003 [03:23<00:00, 9.83it/s, rgb=14.02, depth=14.03, rgbd=32.45]
Stats Ep 23: rgb: 43.9017 | rgbd: 47.0878 | depth: 24.1906
Ep 24/100: 100%|████████████████████████████████| 2003/2003 [03:23<00:00, 9.82it/s, rgb=20.07, depth=22.95, rgbd=13.98]
Stats Ep 24: rgb: 32.3447 | rgbd: 15.3120 | depth: 23.2404
Ep 25/100: 100%|████████████████████████████████| 2003/2003 [03:32<00:00, 9.42it/s, rgb=20.88, depth=70.07, rgbd=11.84]
Stats Ep 25: rgb: 20.3500 | rgbd: 398.7645 | depth: 22.8858
Ep 26/100: 100%|████████████████████████████████| 2003/2003 [03:32<00:00, 9.44it/s, rgb=13.29, depth=22.21, rgbd=26.74]
Stats Ep 26: rgb: 18.3350 | rgbd: 17.0506 | depth: 20.4216
Ep 27/100: 100%|█████████████████████████████████| 2003/2003 [03:28<00:00, 9.63it/s, rgb=6.05, depth=21.77, rgbd=16.99]
Stats Ep 27: rgb: 19.1550 | rgbd: 42.2067 | depth: 19.2607
Ep 28/100: 100%|████████████████████████████████| 2003/2003 [03:37<00:00, 9.20it/s, rgb=18.11, depth=10.99, rgbd=12.60]
Stats Ep 28: rgb: 14.9264 | rgbd: 82.7620 | depth: 22.0940
Ep 29/100: 100%|████████████████████████████████| 2003/2003 [03:28<00:00, 9.59it/s, rgb=33.31, depth=13.02, rgbd=13.52]
Stats Ep 29: rgb: 16.0388 | rgbd: 430.7937 | depth: 42.4016
Ep 30/100: 100%|█████████████████████████████████| 2003/2003 [03:22<00:00, 9.87it/s, rgb=21.54, depth=17.04, rgbd=9.08]
Stats Ep 30: rgb: 11.5473 | rgbd: 203.8561 | depth: 18.0487
Ep 31/100: 100%|███████████████████████████████| 2003/2003 [03:17<00:00, 10.15it/s, rgb=210.86, depth=32.25, rgbd=25.20]
Stats Ep 31: rgb: 10.1994 | rgbd: 12.9020 | depth: 23.8446
Ep 32/100: 100%|█████████████████████████████████| 2003/2003 [03:14<00:00, 10.30it/s, rgb=13.98, depth=23.77, rgbd=8.67]
Stats Ep 32: rgb: 11.5724 | rgbd: 26.1964 | depth: 20.1888
Ep 33/100: 100%|████████████████████████████████| 2003/2003 [03:42<00:00, 9.01it/s, rgb=13.91, depth=13.11, rgbd=14.64]
Stats Ep 33: rgb: 12.0045 | rgbd: 19.7310 | depth: 19.1738
Ep 34/100: 100%|█████████████████████████████████| 2003/2003 [03:19<00:00, 10.02it/s, rgb=8.83, depth=15.48, rgbd=11.42]
Stats Ep 34: rgb: 9.6920 | rgbd: 411.2325 | depth: 20.5772
Ep 35/100: 100%|████████████████████████████████| 2003/2003 [03:43<00:00, 8.95it/s, rgb=17.46, depth=26.63, rgbd=43.02]
Stats Ep 35: rgb: 11.6634 | rgbd: 23.1183 | depth: 18.8303
Ep 36/100: 100%|█████████████████████████████████| 2003/2003 [03:21<00:00, 9.96it/s, rgb=9.50, depth=11.77, rgbd=27.22]
Stats Ep 36: rgb: 8.5831 | rgbd: 15.1324 | depth: 19.4143
Ep 37/100: 100%|████████████████████████████████| 2003/2003 [03:54<00:00, 8.55it/s, rgb=21.91, depth=11.54, rgbd=14.98]
Stats Ep 37: rgb: 10.0505 | rgbd: 12.1884 | depth: 19.8651
Ep 38/100: 100%|████████████████████████████████| 2003/2003 [03:35<00:00, 9.32it/s, rgb=11.41, depth=12.67, rgbd=15.16]
Stats Ep 38: rgb: 12.0865 | rgbd: 14.3607 | depth: 21.4308
Ep 39/100: 100%|██████████████████████████████████| 2003/2003 [03:27<00:00, 9.63it/s, rgb=9.50, depth=13.99, rgbd=6.27]
Stats Ep 39: rgb: 9.5172 | rgbd: 11.1677 | depth: 20.5946
Ep 40/100: 100%|██████████████████████████████████| 2003/2003 [03:59<00:00, 8.36it/s, rgb=16.41, depth=9.26, rgbd=8.56]
Stats Ep 40: rgb: 9.3472 | rgbd: 43.2814 | depth: 18.9526
Ep 41/100: 100%|█████████████████████████████████| 2003/2003 [03:39<00:00, 9.14it/s, rgb=11.89, depth=12.34, rgbd=5.48]
Stats Ep 41: rgb: 9.0334 | rgbd: 177.1736 | depth: 20.1144
Ep 42/100: 100%|████████████████████████████████| 2003/2003 [04:02<00:00, 8.25it/s, rgb=16.41, depth=22.63, rgbd=15.82]
Stats Ep 42: rgb: 12.0604 | rgbd: 230.0422 | depth: 21.6746
Ep 43/100: 100%|█████████████████████████████████| 2003/2003 [03:23<00:00, 9.85it/s, rgb=9.13, depth=13.72, rgbd=12.86]
Stats Ep 43: rgb: 10.1213 | rgbd: 10.7602 | depth: 18.9803
Ep 44/100: 100%|████████████████████████████████| 2003/2003 [04:12<00:00, 7.92it/s, rgb=12.48, depth=50.45, rgbd=12.10]
Stats Ep 44: rgb: 9.3887 | rgbd: 12.3186 | depth: 20.2931
Ep 45/100: 100%|███████████████████████████████████| 2003/2003 [03:24<00:00, 9.78it/s, rgb=3.49, depth=7.05, rgbd=5.68]
Stats Ep 45: rgb: 10.8346 | rgbd: 25.5786 | depth: 19.8172
Ep 46/100: 100%|███████████████████████████████████| 2003/2003 [03:59<00:00, 8.36it/s, rgb=6.69, depth=9.45, rgbd=5.41]
Stats Ep 46: rgb: 10.2943 | rgbd: 10.2819 | depth: 18.2244
Ep 47/100: 100%|███████████████████████████████████| 2003/2003 [03:28<00:00, 9.62it/s, rgb=3.13, depth=4.34, rgbd=6.76]
Stats Ep 47: rgb: 9.2400 | rgbd: 10.0706 | depth: 19.9477
Ep 48/100: 100%|█████████████████████████████████| 2003/2003 [03:45<00:00, 8.90it/s, rgb=9.22, depth=24.94, rgbd=17.79]
Stats Ep 48: rgb: 10.0708 | rgbd: 66.7637 | depth: 19.7785
Ep 49/100: 100%|███████████████████████████████████| 2003/2003 [03:49<00:00, 8.72it/s, rgb=2.99, depth=6.92, rgbd=5.06]
Stats Ep 49: rgb: 10.2520 | rgbd: 310.7913 | depth: 20.0406
Ep 50/100: 100%|███████████████████████████████████| 2003/2003 [03:19<00:00, 10.04it/s, rgb=6.60, depth=9.30, rgbd=9.28]
Stats Ep 50: rgb: 10.4128 | rgbd: 544.3668 | depth: 17.8276
Ep 51/100: 100%|████████████████████████████████| 2003/2003 [03:42<00:00, 9.00it/s, rgb=14.67, depth=25.36, rgbd=20.45]
Stats Ep 51: rgb: 10.1103 | rgbd: 9.9542 | depth: 20.5482
Ep 52/100: 100%|████████████████████████████████| 2003/2003 [03:36<00:00, 9.27it/s, rgb=15.20, depth=45.26, rgbd=36.52]
Stats Ep 52: rgb: 11.1264 | rgbd: 15.9574 | depth: 20.1303
Ep 53/100: 100%|███████████████████████████████████| 2003/2003 [03:37<00:00, 9.22it/s, rgb=3.36, depth=8.11, rgbd=5.21]
Stats Ep 53: rgb: 11.2692 | rgbd: 16.9219 | depth: 20.3264
Ep 54/100: 100%|████████████████████████████████| 2003/2003 [03:53<00:00, 8.59it/s, rgb=13.29, depth=19.29, rgbd=11.41]
Stats Ep 54: rgb: 10.1431 | rgbd: 10.5608 | depth: 19.5169
Ep 55/100: 100%|██████████████████████████████████| 2003/2003 [03:32<00:00, 9.43it/s, rgb=8.54, depth=23.72, rgbd=8.05]
Stats Ep 55: rgb: 10.6635 | rgbd: 10.0462 | depth: 19.6956
Ep 56/100: 100%|█████████████████████████████████| 2003/2003 [03:40<00:00, 9.07it/s, rgb=9.80, depth=27.86, rgbd=14.00]
Stats Ep 56: rgb: 11.5389 | rgbd: 9.6502 | depth: 19.1459
Ep 57/100: 100%|███████████████████████████████████| 2003/2003 [04:36<00:00, 7.24it/s, rgb=6.28, depth=4.51, rgbd=6.18]
Stats Ep 57: rgb: 9.9811 | rgbd: 10.6208 | depth: 18.1353
Ep 58/100: 100%|██████████████████████████████████| 2003/2003 [03:48<00:00, 8.76it/s, rgb=6.81, depth=8.42, rgbd=11.55]
Stats Ep 58: rgb: 11.1203 | rgbd: 8.4378 | depth: 19.0683
Ep 59/100: 100%|██████████████████████████████████| 2003/2003 [04:30<00:00, 7.42it/s, rgb=7.23, depth=16.25, rgbd=4.34]
Stats Ep 59: rgb: 10.7293 | rgbd: 13.8946 | depth: 19.2964
Ep 60/100: 100%|███████████████████████████████████| 2003/2003 [03:36<00:00, 9.26it/s, rgb=4.76, depth=6.33, rgbd=6.50]
Stats Ep 60: rgb: 10.7495 | rgbd: 9.9907 | depth: 19.2687
Ep 61/100: 100%|██████████████████████████████████| 2003/2003 [04:13<00:00, 7.89it/s, rgb=3.62, depth=8.61, rgbd=10.02]
Stats Ep 61: rgb: 10.1712 | rgbd: 9.0362 | depth: 19.0882
Ep 62/100: 100%|██████████████████████████████████| 2003/2003 [04:08<00:00, 8.06it/s, rgb=7.16, depth=20.66, rgbd=4.65]
Stats Ep 62: rgb: 10.7153 | rgbd: 10.4034 | depth: 20.1984
Ep 63/100: 100%|████████████████████████████████| 2003/2003 [03:29<00:00, 9.56it/s, rgb=33.07, depth=26.90, rgbd=38.74]
Stats Ep 63: rgb: 11.7746 | rgbd: 9.6731 | depth: 20.4930
Ep 64/100: 100%|████████████████████████████████| 2003/2003 [04:00<00:00, 8.32it/s, rgb=12.35, depth=16.26, rgbd=13.93]
Stats Ep 64: rgb: 11.8854 | rgbd: 9.4098 | depth: 19.8922
Ep 65/100: 100%|██████████████████████████████████| 2003/2003 [03:32<00:00, 9.41it/s, rgb=8.73, depth=17.59, rgbd=8.63]
Stats Ep 65: rgb: 11.0344 | rgbd: 8.3085 | depth: 17.8676
Ep 66/100: 100%|████████████████████████████████| 2003/2003 [03:43<00:00, 8.95it/s, rgb=16.67, depth=37.13, rgbd=14.22]
Stats Ep 66: rgb: 12.0704 | rgbd: 9.2295 | depth: 20.5916
Ep 67/100: 100%|██████████████████████████████████| 2003/2003 [03:28<00:00, 9.61it/s, rgb=7.01, depth=13.76, rgbd=6.28]
Stats Ep 67: rgb: 11.1012 | rgbd: 9.6036 | depth: 19.6591
Ep 68/100: 100%|██████████████████████████████████| 2003/2003 [03:50<00:00, 8.68it/s, rgb=5.94, depth=14.25, rgbd=8.76]
Stats Ep 68: rgb: 10.7577 | rgbd: 7.3761 | depth: 18.9032
Ep 69/100: 100%|█████████████████████████████████| 2003/2003 [03:51<00:00, 8.64it/s, rgb=5.17, depth=18.33, rgbd=11.87]
Stats Ep 69: rgb: 11.3961 | rgbd: 9.0704 | depth: 18.1821
Ep 70/100: 100%|███████████████████████████████████| 2003/2003 [03:51<00:00, 8.66it/s, rgb=4.86, depth=9.76, rgbd=8.22]
Stats Ep 70: rgb: 10.8916 | rgbd: 8.2816 | depth: 20.1999
Ep 71/100: 100%|██████████████████████████████████| 2003/2003 [03:33<00:00, 9.38it/s, rgb=3.80, depth=24.90, rgbd=5.73]
Stats Ep 71: rgb: 11.2070 | rgbd: 8.4263 | depth: 20.6307
Ep 72/100: 100%|███████████████████████████████████| 2003/2003 [03:42<00:00, 9.02it/s, rgb=6.06, depth=7.28, rgbd=2.69]
Stats Ep 72: rgb: 11.7598 | rgbd: 9.4798 | depth: 19.7261
Ep 73/100: 100%|█████████████████████████████████| 2003/2003 [03:59<00:00, 8.37it/s, rgb=12.73, depth=18.86, rgbd=5.67]
Stats Ep 73: rgb: 10.7669 | rgbd: 7.8037 | depth: 20.3094
Ep 74/100: 100%|█████████████████████████████████| 2003/2003 [03:58<00:00, 8.41it/s, rgb=10.58, depth=14.65, rgbd=4.81]
Stats Ep 74: rgb: 10.6954 | rgbd: 8.6684 | depth: 18.1413
Ep 75/100: 100%|████████████████████████████████| 2003/2003 [03:24<00:00, 9.78it/s, rgb=26.53, depth=31.42, rgbd=88.70]
Stats Ep 75: rgb: 11.5845 | rgbd: 8.1621 | depth: 21.4552
Ep 76/100: 100%|█████████████████████████████████| 2003/2003 [03:18<00:00, 10.12it/s, rgb=7.62, depth=10.48, rgbd=12.33]
Stats Ep 76: rgb: 10.4904 | rgbd: 7.3135 | depth: 18.7806
Ep 77/100: 100%|████████████████████████████████| 2003/2003 [04:00<00:00, 8.32it/s, rgb=23.31, depth=50.30, rgbd=24.11]
Stats Ep 77: rgb: 11.3327 | rgbd: 8.2738 | depth: 20.7111
Ep 78/100: 100%|██████████████████████████████████| 2003/2003 [03:56<00:00, 8.47it/s, rgb=3.58, depth=14.82, rgbd=6.78]
Stats Ep 78: rgb: 10.5376 | rgbd: 7.7449 | depth: 17.8819
Ep 79/100: 100%|██████████████████████████████████| 2003/2003 [03:54<00:00, 8.53it/s, rgb=3.45, depth=14.50, rgbd=7.03]
Stats Ep 79: rgb: 11.3937 | rgbd: 8.5499 | depth: 19.8097
Ep 80/100: 100%|██████████████████████████████████| 2003/2003 [04:20<00:00, 7.68it/s, rgb=1.91, depth=14.95, rgbd=6.46]
Stats Ep 80: rgb: 10.9188 | rgbd: 7.5601 | depth: 20.2093
Ep 81/100: 100%|███████████████████████████████████| 2003/2003 [03:36<00:00, 9.27it/s, rgb=6.82, depth=9.80, rgbd=6.13]
Stats Ep 81: rgb: 10.5725 | rgbd: 7.1534 | depth: 19.3184
Ep 82/100: 100%|██████████████████████████████████| 2003/2003 [04:27<00:00, 7.50it/s, rgb=9.58, depth=28.38, rgbd=7.25]
Stats Ep 82: rgb: 11.1110 | rgbd: 7.1721 | depth: 21.3320
Ep 83/100: 100%|███████████████████████████████████| 2003/2003 [03:50<00:00, 8.69it/s, rgb=4.11, depth=8.64, rgbd=2.05]
Stats Ep 83: rgb: 10.8538 | rgbd: 7.4928 | depth: 20.2985
Ep 84/100: 100%|███████████████████████████████████| 2003/2003 [03:51<00:00, 8.67it/s, rgb=4.99, depth=3.75, rgbd=4.56]
Stats Ep 84: rgb: 10.8706 | rgbd: 7.2706 | depth: 19.2181
Ep 85/100: 100%|█████████████████████████████████| 2003/2003 [03:45<00:00, 8.88it/s, rgb=11.62, depth=10.85, rgbd=7.61]
Stats Ep 85: rgb: 11.5781 | rgbd: 7.3482 | depth: 19.9551
Ep 86/100: 100%|██████████████████████████████████| 2003/2003 [03:51<00:00, 8.67it/s, rgb=7.63, depth=20.37, rgbd=8.34]
Stats Ep 86: rgb: 12.1006 | rgbd: 7.6266 | depth: 20.5909
Ep 87/100: 100%|███████████████████████████████████| 2003/2003 [03:35<00:00, 9.28it/s, rgb=4.29, depth=6.06, rgbd=5.83]
Stats Ep 87: rgb: 11.4818 | rgbd: 8.5980 | depth: 20.3973
Ep 88/100: 100%|█████████████████████████████████| 2003/2003 [03:46<00:00, 8.85it/s, rgb=5.47, depth=23.82, rgbd=11.21]
Stats Ep 88: rgb: 11.3592 | rgbd: 8.3917 | depth: 21.0186
Ep 89/100: 100%|███████████████████████████████████| 2003/2003 [03:35<00:00, 9.31it/s, rgb=3.40, depth=5.67, rgbd=4.33]
Stats Ep 89: rgb: 10.9975 | rgbd: 7.7200 | depth: 19.8255
Ep 90/100: 100%|████████████████████████████████| 2003/2003 [03:33<00:00, 9.40it/s, rgb=11.59, depth=21.37, rgbd=12.20]
Stats Ep 90: rgb: 11.6379 | rgbd: 8.1928 | depth: 21.4311
Ep 91/100: 100%|███████████████████████████████████| 2003/2003 [03:43<00:00, 8.96it/s, rgb=3.98, depth=6.28, rgbd=3.13]
Stats Ep 91: rgb: 10.8083 | rgbd: 7.0418 | depth: 19.9876
Ep 92/100: 100%|████████████████████████████████| 2003/2003 [04:22<00:00, 7.64it/s, rgb=17.88, depth=31.89, rgbd=13.72]
Stats Ep 92: rgb: 10.3190 | rgbd: 7.5683 | depth: 19.3219
Ep 93/100: 100%|██████████████████████████████████| 2003/2003 [03:26<00:00, 9.69it/s, rgb=4.58, depth=24.49, rgbd=4.37]
Stats Ep 93: rgb: 10.2634 | rgbd: 6.6631 | depth: 19.4197
Ep 94/100: 100%|██████████████████████████████████| 2003/2003 [04:04<00:00, 8.19it/s, rgb=5.53, depth=7.15, rgbd=14.48]
Stats Ep 94: rgb: 11.2520 | rgbd: 8.3678 | depth: 20.4317
Ep 95/100: 100%|███████████████████████████████████| 2003/2003 [03:37<00:00, 9.19it/s, rgb=4.93, depth=7.55, rgbd=3.45]
Stats Ep 95: rgb: 11.3476 | rgbd: 7.3319 | depth: 20.8723
Ep 96/100: 100%|██████████████████████████████████| 2003/2003 [03:28<00:00, 9.59it/s, rgb=6.87, depth=11.18, rgbd=7.36]
Stats Ep 96: rgb: 11.5028 | rgbd: 8.4857 | depth: 21.5285
Ep 97/100: 100%|██████████████████████████████████| 2003/2003 [03:45<00:00, 8.86it/s, rgb=5.73, depth=11.61, rgbd=5.23]
Stats Ep 97: rgb: 11.0380 | rgbd: 7.0677 | depth: 18.7722
Ep 98/100: 100%|███████████████████████████████████| 2003/2003 [03:29<00:00, 9.56it/s, rgb=7.69, depth=8.28, rgbd=7.22]
Stats Ep 98: rgb: 11.1428 | rgbd: 7.2706 | depth: 19.1217
Ep 99/100: 100%|██████████████████████████████████| 2003/2003 [04:28<00:00, 7.45it/s, rgb=2.80, depth=16.02, rgbd=2.54]
Stats Ep 99: rgb: 11.5821 | rgbd: 7.5639 | depth: 20.6576
Ep 100/100: 100%|██████████████████████████████████| 2003/2003 [05:29<00:00, 6.08it/s, rgb=6.42, depth=7.36, rgbd=8.87]
Stats Ep 100: rgb: 11.0728 | rgbd: 6.8794 | depth: 20.3297
=== Final Results (Best MSE) ===
Mode rgb: 8.5831
Mode depth: 15.4920
Mode rgbd: 6.6631
This task demonstrates ISAC Beam Prediction using an expanded 256-DFT Codebook. It includes two steps: label generation and model training.
Step 1: Label Generation
Use the multipath components to calculate the optimal beam index from a 16x16 (256-beam) oversampled DFT codebook defined for the 8x8 UPA at the base station.
Show Code
import numpy as np
import os
import glob
import pandas as pd
# --- Configuration ---
FC = 4.9e9 # Carrier frequency 4.9 GHz
C = 299792458 # Speed of light (m/s)
WAVELENGTH = C / FC
ANTENNA_SPACING = WAVELENGTH / 2 # Half-wavelength spacing
# BS Antenna Config (8x8 UPA)
Nx, Ny = 8, 8
NUM_ANTENNAS = Nx * Ny  # 64 elements
# Beam Config (Oversampled: 2x per axis relative to the 8x8 array)
Mx, My = 16, 16 # 16x16 = 256 beams
NUM_BEAMS = Mx * My
def get_upa_steering_vector(theta, phi, N_x, N_y):
    """Return the UPA array response for a given direction.

    Args:
        theta: Elevation angle in radians (0 = zenith).
        phi: Azimuth angle in radians.
        N_x: Number of elements along the x-axis.
        N_y: Number of elements along the y-axis.

    Returns:
        Complex array of shape (N_x * N_y,) for a half-wavelength-spaced
        array lying in the XY plane.
    """
    # Direction cosines projected onto the two array axes.
    sin_el = np.sin(theta)
    u = sin_el * np.cos(phi)
    v = sin_el * np.sin(phi)
    # Per-axis phase progression; d = lambda/2 gives a phase step of
    # pi * (direction cosine) per element (2*pi*0.5 below).
    phase_x = np.exp(1j * 2 * np.pi * 0.5 * np.arange(N_x) * u)
    phase_y = np.exp(1j * 2 * np.pi * 0.5 * np.arange(N_y) * v)
    # Full planar response is the Kronecker product of the axis responses.
    return np.kron(phase_x, phase_y)
def build_channel_from_multipath(npz_path):
    """Synthesize the MISO channel vector h from stored multipath data.

    Reads per-path complex gains, departure angles, and delays from the
    .npz file and accumulates alpha_i * exp(-j*2*pi*fc*tau_i) * a_t(i)
    over all paths (Rx response assumed to be 1).

    Args:
        npz_path: Path to a csi_xxxxxx.npz multipath file.

    Returns:
        Complex channel vector of shape (NUM_ANTENNAS,).
    """
    mp = np.load(npz_path)
    # Flatten every field so (L, 1)-shaped arrays behave like (L,).
    gains = (mp['a_real'] + 1j * mp['a_imag']).flatten()      # complex path gains
    aod_theta = mp['theta_t'].flatten()                       # departure elevation
    aod_phi = mp['phi_t'].flatten()                           # departure azimuth
    path_delays = mp['tau'].flatten()                         # path delays
    h = np.zeros(NUM_ANTENNAS, dtype=complex)
    # Sum the narrowband contribution of every path.
    for gain, el, az, tau in zip(gains, aod_theta, aod_phi, path_delays):
        # Carrier-frequency phase rotation induced by the path delay.
        delay_phase = np.exp(-1j * 2 * np.pi * FC * tau)
        # Tx steering vector toward the path's departure direction.
        h += gain * delay_phase * get_upa_steering_vector(el, az, Nx, Ny)
    return h
def create_oversampled_dft_codebook(N, M):
    """Build a 1D oversampled DFT codebook.

    Args:
        N: Number of antenna elements.
        M: Number of beams (M > N yields oversampling).

    Returns:
        Complex matrix of shape (N, M); column k is the k-th DFT beam,
        scaled by 1/sqrt(N) for unit codeword power.
    """
    element_idx = np.arange(N)[:, None]   # (N, 1)
    beam_idx = np.arange(M)[None, :]      # (1, M)
    # W[n, k] = exp(j * 2*pi * n * k / M) / sqrt(N)
    return np.exp(1j * 2 * np.pi * element_idx * beam_idx / M) / np.sqrt(N)
def create_codebook(N_x, N_y, M_x, M_y):
    """Build the 2D oversampled DFT codebook for an N_x x N_y UPA.

    Args:
        N_x, N_y: Antenna counts along each axis.
        M_x, M_y: Beam counts along each axis.

    Returns:
        Complex matrix of shape (N_x*N_y, M_x*M_y); each column is one
        beam codeword, formed as the Kronecker product of the two 1D
        per-axis DFT codebooks.
    """
    horizontal = create_oversampled_dft_codebook(N_x, M_x)
    vertical = create_oversampled_dft_codebook(N_y, M_y)
    return np.kron(horizontal, vertical)
def get_optimal_beam_label(h, codebook):
    """Select the strongest beam for a channel vector.

    Args:
        h: Channel vector, shape (num_antennas,).
        codebook: Beam codewords, shape (num_antennas, num_beams).

    Returns:
        Tuple (best_beam_idx, powers): index of the maximum-power beam
        and the per-beam received power vector of shape (num_beams,).
    """
    # Matched-filter every codeword against the channel: y_k = w_k^H h.
    beam_outputs = codebook.conj().T @ h
    # Received power per beam.
    beam_powers = np.abs(beam_outputs) ** 2
    return np.argmax(beam_powers), beam_powers
# --- Batch Processing ---
if __name__ == "__main__":
    # Paths
    INPUT_DIR = r"D:\multi_path_npz"
    # Save results next to this script.
    OUTPUT_CSV = os.path.join(os.path.dirname(os.path.abspath(__file__)), "beam_labels.csv")
    # 1. Generate the codebook once; it is reused for every file.
    print(f"Generating Oversampled DFT Codebook ({Mx}x{My} = {NUM_BEAMS} beams)...")
    try:
        W = create_codebook(Nx, Ny, Mx, My)
        print(f"Codebook shape: {W.shape}")
        # 2. Collect all CSI .npz files.
        pattern = os.path.join(INPUT_DIR, "csi_*.npz")
        files = sorted(glob.glob(pattern))
        print(f"Found {len(files)} data files.")
        results = []
        # 3. Batch processing: one label per file; failures are logged
        # with the offending filename and skipped.
        print("Starting batch processing...")
        for idx, file_path in enumerate(files):
            # Resolve the name before the try block so the except handler
            # can always report which file failed.
            filename = os.path.basename(file_path)
            try:
                # Synthesize the channel and pick the strongest beam.
                h_channel = build_channel_from_multipath(file_path)
                label, power_vectors = get_optimal_beam_label(h_channel, W)
                max_power = power_vectors[label]
                results.append({
                    'filename': filename,
                    'beam_index': label,
                    'received_power': max_power
                })
                # Log progress every 100 files
                if (idx + 1) % 100 == 0:
                    print(f"Processed {idx + 1}/{len(files)} files...")
            except Exception as e:
                # Fixed: report the failing file instead of "(unknown)".
                print(f"Error processing {filename}: {e}")
        # 4. Save results
        if results:
            df = pd.DataFrame(results)
            df.to_csv(OUTPUT_CSV, index=False)
            print(f"\nFinished! Results saved to: {OUTPUT_CSV}")
            print("First 5 rows:")
            print(df.head())
        else:
            print("No results generated. Please check the paths.")
    except Exception as e:
        print(f"An error occurred: {e}")
Step 2: Ablation Study (RGB vs Depth vs Fusion)
Design an ISAC_MobileNet to predict the optimal beam index (0-255). The script compares performance using RGB-only, Depth-only, and Fusion inputs.
Show Code
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd
import os
import numpy as np
# --- 1. Network Model Design (Based on MobileNetV2) ---
class ISAC_MobileNet(nn.Module):
    """MobileNetV2-based beam classifier for ISAC vision inputs.

    The stock MobileNetV2 is adapted in two places:
      * the stem conv layer is rebuilt to accept ``in_channels`` inputs
        (RGB+D=4, RGB=3, Depth=1 for the ablation study);
      * the final classifier is replaced by a ``num_classes``-way head
        (one logit per codebook beam).

    Args:
        num_classes: number of beam indices to predict (codebook size).
        in_channels: number of input image channels.
        pretrained: if True, initialize the backbone from ImageNet weights
            (the rebuilt stem conv is still randomly initialized).
    """

    def __init__(self, num_classes=256, in_channels=4, pretrained=False):
        super(ISAC_MobileNet, self).__init__()
        # BUGFIX: `pretrained` used to be accepted but silently ignored
        # (weights were always None). Honor it here; the default (False)
        # preserves the original random-initialization behavior.
        weights = models.MobileNet_V2_Weights.DEFAULT if pretrained else None
        self.backbone = models.mobilenet_v2(weights=weights)

        # --- Modify Input Layer ---
        # Rebuild the first conv to accept the requested channel count while
        # copying every other hyper-parameter from the original layer.
        original_first_layer = self.backbone.features[0][0]
        new_first_layer = nn.Conv2d(
            in_channels=in_channels,
            out_channels=original_first_layer.out_channels,
            kernel_size=original_first_layer.kernel_size,
            stride=original_first_layer.stride,
            padding=original_first_layer.padding,
            bias=False
        )
        self.backbone.features[0][0] = new_first_layer

        # --- Modify Output Layer (Classifier) ---
        # MobileNetV2's classification head lives at classifier[1].
        in_features = self.backbone.classifier[1].in_features
        self.backbone.classifier[1] = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return raw (unnormalized) beam logits of shape (B, num_classes)."""
        return self.backbone(x)
# --- 2. Data Loader Design ---
class RGBDBeamDataset(Dataset):
    """Paired RGB + depth dataset labeled with optimal beam indices.

    Each row of ``csv_file`` (produced by the beam-labeling script) names a
    ``csi_XXXXXX.npz`` file; the matching ``img_XXXXXX.png`` and
    ``depth_XXXXXX.npz`` are loaded, per-image standardized, and fused into
    a single (4, 512, 512) tensor (channels 0-2 RGB, channel 3 depth).
    """

    def __init__(self, csv_file, rgb_dir, depth_dir, transform=None):
        """
        Args:
            csv_file: path to beam_labels.csv generated by generateH.py
            rgb_dir: path to RGB image directory
            depth_dir: path to depth map directory
            transform: optional transform applied to BOTH the RGB image and
                the depth map (defaults to 512x512 resize + ToTensor).
        """
        self.df = pd.read_csv(csv_file)
        self.rgb_dir = rgb_dir
        self.depth_dir = depth_dir
        # Default preprocessing. NOTE: ToTensor only rescales to [0, 1] for
        # 8-bit images; mode-'F' float depth maps pass through unscaled
        # (they are standardized per-image in __getitem__ instead).
        if transform is None:
            self.transform = transforms.Compose([
                transforms.Resize((512, 512)),
                transforms.ToTensor(),
            ])
        else:
            self.transform = transform
        # ImageNet channel statistics. NOTE(review): currently unused —
        # __getitem__ applies per-image standardization instead. Kept so
        # any external code referencing this attribute keeps working.
        self.rgb_normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # 1. Filename and beam label for this sample.
        row = self.df.iloc[idx]
        npz_name = row['filename']  # e.g., csi_000000.npz
        label = int(row['beam_index'])

        # 2. Derive the paired image filenames from the shared numeric id
        #    (csi_000000.npz -> img_000000.png / depth_000000.npz).
        file_id = npz_name.replace('csi_', '').replace('.npz', '')
        rgb_name = f"img_{file_id}.png"
        depth_name = f"depth_{file_id}.npz"  # Depth in .npz format
        rgb_path = os.path.join(self.rgb_dir, rgb_name)
        depth_path = os.path.join(self.depth_dir, depth_name)

        # 3. Load both modalities.
        try:
            rgb_img = Image.open(rgb_path).convert('RGB')
            with np.load(depth_path) as data:
                # The archive key is not standardized; take the first array.
                key = data.files[0]
                depth_arr = data[key]
            # BUGFIX: Image.fromarray(mode='F') requires a 2-D float32
            # array; cast (and drop any singleton channel axis) so depth
            # maps saved as float64 or (H, W, 1) no longer raise here.
            depth_arr = np.squeeze(np.asarray(depth_arr, dtype=np.float32))
            depth_img = Image.fromarray(depth_arr, mode='F')  # 32-bit float image
        except Exception as e:
            print(f"Error reading image: {rgb_path} or {depth_path} : {e}")
            # Best-effort fallback: emit a zero tensor so training continues.
            return torch.zeros(4, 512, 512), label

        # 4. Identical resize for both modalities keeps them pixel-aligned.
        rgb_tensor = self.transform(rgb_img)      # (3, 512, 512), in [0, 1]
        depth_tensor = self.transform(depth_img)  # (1, 512, 512), raw depth units

        # 5. Per-image standardization (zero mean, unit variance); the std
        #    guard avoids division blow-ups on constant images.
        rgb_mean = rgb_tensor.mean()
        rgb_std = rgb_tensor.std()
        if rgb_std > 1e-6:
            rgb_tensor = (rgb_tensor - rgb_mean) / rgb_std
        else:
            rgb_tensor = rgb_tensor - rgb_mean

        # --- Depth Normalization ---
        d_mean = depth_tensor.mean()
        d_std = depth_tensor.std()
        if d_std > 1e-6:
            depth_tensor = (depth_tensor - d_mean) / d_std
        else:
            depth_tensor = depth_tensor - d_mean

        # 6. Early fusion: concatenate along channels -> (4, 512, 512).
        input_tensor = torch.cat([rgb_tensor, depth_tensor], dim=0)
        return input_tensor, label
# --- 3. Training/Testing Example ---
if __name__ == "__main__":
    from sklearn.model_selection import train_test_split
    from tqdm import tqdm

    # --- 1. Settings ---
    csv_path = r"d:\exp\beam_labels.csv"
    rgb_path = r"E:\Datasets1\San Francisco\Scene 1\roof_bs_01\cam"
    depth_path = r"E:\Datasets1\San Francisco\Scene 1\roof_bs_01\depth"
    BATCH_SIZE = 16
    EPOCHS = 10
    LEARNING_RATE = 0.001

    # Auto device selection
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # BUGFIX: mixed precision used to target 'cuda' unconditionally even
    # though the device is auto-selected above; on CPU-only hosts that
    # emits warnings and silently disables gradient scaling. Gate AMP (and
    # cuDNN autotuning, which is CUDA-only) on the actual device type.
    use_amp = device.type == 'cuda'
    if use_amp:
        # Fixed input size -> let cuDNN pick the fastest conv algorithms.
        torch.backends.cudnn.benchmark = True

    # --- 2. Data Splitting ---
    full_df = pd.read_csv(csv_path)
    # 80/20 train/test split; fixed seed for reproducibility.
    train_df, test_df = train_test_split(full_df, test_size=0.2, random_state=42, shuffle=True)
    # RGBDBeamDataset takes a CSV path, so round-trip through temp files.
    train_df.to_csv('temp_train.csv', index=False)
    test_df.to_csv('temp_test.csv', index=False)
    train_dataset = RGBDBeamDataset('temp_train.csv', rgb_path, depth_path)
    test_dataset = RGBDBeamDataset('temp_test.csv', rgb_path, depth_path)

    # --- Optimization: Multi-worker data loading ---
    # Windows typically handles 4 workers reasonably well
    num_workers = 4 if os.name == 'nt' else 8
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=num_workers, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False,
                             num_workers=num_workers, pin_memory=True)

    # --- 3. Model, Loss, Optimizer ---
    # Three independent models trained in lockstep for the ablation study.
    print("Initializing models for Ablation Study...")
    # 1. RGB+Depth (4 channels)
    model_all = ISAC_MobileNet(num_classes=256, in_channels=4, pretrained=False).to(device)
    opt_all = torch.optim.Adam(model_all.parameters(), lr=LEARNING_RATE)
    # 2. RGB Only (3 channels)
    model_rgb = ISAC_MobileNet(num_classes=256, in_channels=3, pretrained=False).to(device)
    opt_rgb = torch.optim.Adam(model_rgb.parameters(), lr=LEARNING_RATE)
    # 3. Depth Only (1 channel)
    model_depth = ISAC_MobileNet(num_classes=256, in_channels=1, pretrained=False).to(device)
    opt_depth = torch.optim.Adam(model_depth.parameters(), lr=LEARNING_RATE)

    criterion = nn.CrossEntropyLoss()
    # One GradScaler shared by all three optimizers: scale()/step() per
    # optimizer, then a single update() per batch (supported AMP usage).
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

    # --- 4. Training Loop ---
    print("\nStarting Parallel Training for 3 Models...")
    best_acc_all = 0.0
    best_acc_rgb = 0.0
    best_acc_depth = 0.0
    for epoch in range(EPOCHS):
        model_all.train()
        model_rgb.train()
        model_depth.train()
        # Running statistics for the three models
        stats = {
            'all': {'loss': 0.0, 'correct': 0},
            'rgb': {'loss': 0.0, 'correct': 0},
            'depth': {'loss': 0.0, 'correct': 0},
            'total': 0
        }
        loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
        for inputs, labels in loop:
            # non_blocking pairs with pin_memory for async host->GPU copies
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            # inputs is [B, 4, H, W]: channels 0-2 are RGB, channel 3 is depth
            input_rgb = inputs[:, :3, :, :]
            input_depth = inputs[:, 3:, :, :]

            # --- Train Model 1 (RGB+D) ---
            opt_all.zero_grad()
            with torch.amp.autocast('cuda', enabled=use_amp):
                out_all = model_all(inputs)
                loss_all = criterion(out_all, labels)
            scaler.scale(loss_all).backward()
            scaler.step(opt_all)

            # --- Train Model 2 (RGB) ---
            opt_rgb.zero_grad()
            with torch.amp.autocast('cuda', enabled=use_amp):
                out_rgb = model_rgb(input_rgb)
                loss_rgb = criterion(out_rgb, labels)
            scaler.scale(loss_rgb).backward()
            scaler.step(opt_rgb)

            # --- Train Model 3 (Depth) ---
            opt_depth.zero_grad()
            with torch.amp.autocast('cuda', enabled=use_amp):
                out_depth = model_depth(input_depth)
                loss_depth = criterion(out_depth, labels)
            scaler.scale(loss_depth).backward()
            scaler.step(opt_depth)

            # Single loss-scale update per batch, after all three steps.
            scaler.update()

            # --- Stats Update ---
            stats['total'] += labels.size(0)
            # RGB+D
            stats['all']['loss'] += loss_all.item()
            _, pred_all = torch.max(out_all.data, 1)
            stats['all']['correct'] += (pred_all == labels).sum().item()
            # RGB
            stats['rgb']['loss'] += loss_rgb.item()
            _, pred_rgb = torch.max(out_rgb.data, 1)
            stats['rgb']['correct'] += (pred_rgb == labels).sum().item()
            # Depth
            stats['depth']['loss'] += loss_depth.item()
            _, pred_depth = torch.max(out_depth.data, 1)
            stats['depth']['correct'] += (pred_depth == labels).sum().item()
            loop.set_postfix({
                'L_All': f"{loss_all.item():.2f}",
                'L_RGB': f"{loss_rgb.item():.2f}",
                'L_Dep': f"{loss_depth.item():.2f}"
            })

        # Per-epoch training metrics (mean loss per batch, accuracy per sample).
        num_batches = len(train_loader)
        total_samples = stats['total']
        train_res = {}
        for k in ['all', 'rgb', 'depth']:
            train_res[k] = {
                'loss': stats[k]['loss'] / num_batches,
                'acc': stats[k]['correct'] / total_samples
            }

        # --- 5. Testing/Validation ---
        model_all.eval()
        model_rgb.eval()
        model_depth.eval()
        test_stats = {
            'all': {'correct': 0, 'top5': 0},
            'rgb': {'correct': 0, 'top5': 0},
            'depth': {'correct': 0, 'top5': 0},
            'total': 0
        }
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)
                input_rgb = inputs[:, :3, :, :]
                input_depth = inputs[:, 3:, :, :]
                test_stats['total'] += labels.size(0)

                def eval_batch(model, x, key):
                    # Accumulate top-1 and top-5 hits for one model on this batch.
                    out = model(x)
                    # Top 1
                    _, pred = torch.max(out.data, 1)
                    test_stats[key]['correct'] += (pred == labels).sum().item()
                    # Top 5
                    _, top5 = out.topk(5, 1, True, True)
                    top5 = top5.t()
                    correct_row = top5.eq(labels.view(1, -1).expand_as(top5))
                    test_stats[key]['top5'] += correct_row.reshape(-1).float().sum().item()

                eval_batch(model_all, inputs, 'all')
                eval_batch(model_rgb, input_rgb, 'rgb')
                eval_batch(model_depth, input_depth, 'depth')

        # Print Results
        print(f"\nEpoch {epoch+1} Results:")
        print(f" [RGB+D] Train Loss: {train_res['all']['loss']:.4f}, Acc: {train_res['all']['acc']:.4f} | "
              f"Test Acc: {test_stats['all']['correct']/test_stats['total']:.4f}, Top5: {test_stats['all']['top5']/test_stats['total']:.4f}")
        print(f" [RGB ] Train Loss: {train_res['rgb']['loss']:.4f}, Acc: {train_res['rgb']['acc']:.4f} | "
              f"Test Acc: {test_stats['rgb']['correct']/test_stats['total']:.4f}, Top5: {test_stats['rgb']['top5']/test_stats['total']:.4f}")
        print(f" [Depth] Train Loss: {train_res['depth']['loss']:.4f}, Acc: {train_res['depth']['acc']:.4f} | "
              f"Test Acc: {test_stats['depth']['correct']/test_stats['total']:.4f}, Top5: {test_stats['depth']['top5']/test_stats['total']:.4f}")

        # Checkpoint whenever a model posts a new best top-1 test accuracy.
        current_acc_all = test_stats['all']['correct']/test_stats['total']
        current_acc_rgb = test_stats['rgb']['correct']/test_stats['total']
        current_acc_depth = test_stats['depth']['correct']/test_stats['total']
        if current_acc_all > best_acc_all:
            best_acc_all = current_acc_all
            torch.save(model_all.state_dict(), "best_model_rgbd.pth")
        if current_acc_rgb > best_acc_rgb:
            best_acc_rgb = current_acc_rgb
            torch.save(model_rgb.state_dict(), "best_model_rgb.pth")
        if current_acc_depth > best_acc_depth:
            best_acc_depth = current_acc_depth
            torch.save(model_depth.state_dict(), "best_model_depth.pth")

    print(f"\nTraining Finished!")
    print(f"Best Acc - RGB+D: {best_acc_all:.4f}, RGB: {best_acc_rgb:.4f}, Depth: {best_acc_depth:.4f}")

    # Clean temp files
    if os.path.exists('temp_train.csv'): os.remove('temp_train.csv')
    if os.path.exists('temp_test.csv'): os.remove('temp_test.csv')
Step 3: Experimental Results
After a parallel training ablation study over 10 epochs using the San Francisco Urban Block 1 dataset, RGB+D achieves the best Top-1 accuracy (95.86%), followed by RGB (93.51%), while Depth alone reaches 85.13%. Notably, Depth converges much faster in the early epochs (Epoch 1–3), indicating that geometric structure provides a strong inductive bias for low-altitude BS-to-UAV beam prediction. With sufficient training, RGB catches up due to semantic/texture cues, and fusion benefits from complementary geometry + appearance, yielding the highest final performance.
Show Full Training Logs
Epoch 1 Results:
[RGB+D] Train Loss: 2.5777, Acc: 0.2123 | Test Acc: 0.2155, Top5: 0.6824
[RGB ] Train Loss: 2.5822, Acc: 0.2165 | Test Acc: 0.2347, Top5: 0.6648
[Depth] Train Loss: 1.5537, Acc: 0.4881 | Test Acc: 0.6426, Top5: 0.9228
Epoch 2/10: 100%|██████████| 2003/2003 [1:03:07<00:00, 1.89s/it, L_All=2.64, L_RGB=2.41, L_Dep=1.39]
Epoch 2 Results:
[RGB+D] Train Loss: 2.3542, Acc: 0.2445 | Test Acc: 0.2637, Top5: 0.7347
[RGB ] Train Loss: 2.4067, Acc: 0.2342 | Test Acc: 0.2328, Top5: 0.6840
[Depth] Train Loss: 0.8600, Acc: 0.7286 | Test Acc: 0.7823, Top5: 0.9478
Epoch 3/10: 100%|██████████| 2003/2003 [1:02:45<00:00, 1.88s/it, L_All=2.08, L_RGB=2.06, L_Dep=0.24]
Epoch 3 Results:
[RGB+D] Train Loss: 2.0072, Acc: 0.3285 | Test Acc: 0.3759, Top5: 0.8941
[RGB ] Train Loss: 2.2108, Acc: 0.2781 | Test Acc: 0.3363, Top5: 0.8380
[Depth] Train Loss: 0.6372, Acc: 0.8013 | Test Acc: 0.8149, Top5: 0.9502
Epoch 4/10: 100%|██████████| 2003/2003 [1:02:39<00:00, 1.88s/it, L_All=2.41, L_RGB=2.89, L_Dep=1.78]
Epoch 4 Results:
[RGB+D] Train Loss: 1.6091, Acc: 0.4114 | Test Acc: 0.4353, Top5: 0.9564
[RGB ] Train Loss: 1.7301, Acc: 0.3860 | Test Acc: 0.4261, Top5: 0.9417
[Depth] Train Loss: 0.5715, Acc: 0.8216 | Test Acc: 0.7408, Top5: 0.9296
Epoch 5/10: 100%|██████████| 2003/2003 [1:02:38<00:00, 1.88s/it, L_All=2.22, L_RGB=1.45, L_Dep=0.97]
Epoch 5 Results:
[RGB+D] Train Loss: 1.3777, Acc: 0.4518 | Test Acc: 0.4426, Top5: 0.9713
[RGB ] Train Loss: 1.4058, Acc: 0.4487 | Test Acc: 0.4579, Top5: 0.9687
[Depth] Train Loss: 0.5346, Acc: 0.8299 | Test Acc: 0.8271, Top5: 0.9576
Epoch 6/10: 100%|██████████| 2003/2003 [1:01:59<00:00, 1.86s/it, L_All=1.86, L_RGB=1.41, L_Dep=0.43]
Epoch 6 Results:
[RGB+D] Train Loss: 1.2198, Acc: 0.4971 | Test Acc: 0.5335, Top5: 0.9737
[RGB ] Train Loss: 1.2346, Acc: 0.4880 | Test Acc: 0.5217, Top5: 0.9764
[Depth] Train Loss: 0.5025, Acc: 0.8393 | Test Acc: 0.8347, Top5: 0.9571
Epoch 7/10: 100%|██████████| 2003/2003 [1:01:52<00:00, 1.85s/it, L_All=0.34, L_RGB=1.68, L_Dep=0.74]
Epoch 7 Results:
[RGB+D] Train Loss: 0.8043, Acc: 0.7067 | Test Acc: 0.8830, Top5: 0.9948
[RGB ] Train Loss: 1.0366, Acc: 0.5741 | Test Acc: 0.6318, Top5: 0.9879
[Depth] Train Loss: 0.4821, Acc: 0.8441 | Test Acc: 0.8385, Top5: 0.9588
Epoch 8/10: 100%|██████████| 2003/2003 [1:01:59<00:00, 1.86s/it, L_All=0.39, L_RGB=1.26, L_Dep=0.61]
Epoch 8 Results:
[RGB+D] Train Loss: 0.3304, Acc: 0.9049 | Test Acc: 0.9268, Top5: 0.9971
[RGB ] Train Loss: 0.7896, Acc: 0.6944 | Test Acc: 0.7487, Top5: 0.9920
[Depth] Train Loss: 0.4563, Acc: 0.8511 | Test Acc: 0.8280, Top5: 0.9577
Epoch 9/10: 100%|██████████| 2003/2003 [1:02:06<00:00, 1.86s/it, L_All=0.12, L_RGB=0.46, L_Dep=0.05]
Epoch 9 Results:
[RGB+D] Train Loss: 0.2375, Acc: 0.9354 | Test Acc: 0.9576, Top5: 0.9970
[RGB ] Train Loss: 0.5006, Acc: 0.8237 | Test Acc: 0.9180, Top5: 0.9963
[Depth] Train Loss: 0.4543, Acc: 0.8491 | Test Acc: 0.8513, Top5: 0.9596
Epoch 10/10: 100%|██████████| 2003/2003 [1:01:05<00:00, 1.83s/it, L_All=0.24, L_RGB=0.32, L_Dep=0.02]
Epoch 10 Results:
[RGB+D] Train Loss: 0.2055, Acc: 0.9447 | Test Acc: 0.9586, Top5: 0.9986
[RGB ] Train Loss: 0.2668, Acc: 0.9246 | Test Acc: 0.9351, Top5: 0.9971
[Depth] Train Loss: 0.4429, Acc: 0.8542 | Test Acc: 0.8463, Top5: 0.9591
Training Finished!
Best Acc - RGB+D: 0.9586, RGB: 0.9351, Depth: 0.8513